/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_

#include <cmath>
#include <type_traits>

#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/meta/type_traits.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"

namespace xla {

// TODO(b/79274244): We'd like these type traits to live inside of
// HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that
// crashes clang in the frontend.
//
// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is
// a "private" header that's not exposed outside of hlo_evaluator.cc.
template <typename T>
using is_complex_t =
    absl::disjunction<std::is_same<T, complex64>, std::is_same<T, complex128>>;
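// For instance, is_complex_t<complex64>::value and
// is_complex_t<complex128>::value are true, while is_complex_t<float>::value
// is false.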

// ToArithmeticSafeType(T t):
//  - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed
//    integer, and
//  - otherwise returns `t` unchanged.
//
// It's UB in C++ to under/overflow a signed integer, so we wrap all arithmetic
// in this type to force 2's complement behavior.
template <typename T,
          typename std::enable_if<std::is_integral<T>::value &&
                                  std::is_signed<T>::value>::type* = nullptr>
typename std::make_unsigned<T>::type ToArithmeticSafeType(T t) {
  return static_cast<typename std::make_unsigned<T>::type>(t);
}
template <typename T,
          typename std::enable_if<!std::is_integral<T>::value ||
                                  !std::is_signed<T>::value>::type* = nullptr>
T ToArithmeticSafeType(T t) {
  return std::move(t);
}
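
// A minimal sketch of the intent (illustrative values only): arithmetic that
// would overflow a signed type is done in the unsigned domain, where
// wraparound is well-defined, and then converted back:
//   int32 x = std::numeric_limits<int32>::min();
//   int32 y = static_cast<int32>(ToArithmeticSafeType(x) +
//                                ToArithmeticSafeType(x));  // wraps to 0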

// Templated DfsHloVisitor for use by HloEvaluator.
//
// Typically ReturnT here indicates the resulting literal type of each evaluated
// Handle* method of a TypedVisitor.  There are, however, a few notable
// exceptions:
// - HandleCompare and HandleIsFinite: the resulting literal type is always
//   boolean.
// - HandleImag and HandleReal: the resulting literal type is always float,
//   while the operand is always complex (or, for HandleReal, possibly real).
// These operations are handled outside of the parent HloEvaluator handlers
// instead of from within TypedVisitor.
//
// Type params:
//   - ReturnT: The type of input and output of each operation.
//   - ElementwiseT: The type in which internal computations are done.
//
// This is logically a private part of HloEvaluator.  It lives in this header
// file rather than in hlo_evaluator.cc because we use extern templates and a
// bunch of independent cc files to speed up compiling the many instantiations
// of this class.
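//
// For example (a sketch; the concrete instantiations live in the separate cc
// files mentioned above), a bfloat16 computation can be evaluated with float
// arithmetic internally:
//   HloEvaluatorTypedVisitor<bfloat16, float> visitor(&evaluator);
//   TF_RETURN_IF_ERROR(instruction->Accept(&visitor));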
template <typename ReturnT, typename ElementwiseT = ReturnT>
class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 private:
  Status UnsupportedTypeError(HloInstruction* instruction) {
    return InvalidArgument(
        "Unsupported type for %s: %s", HloOpcodeString(instruction->opcode()),
        PrimitiveType_Name(instruction->shape().element_type()));
  }

  // Gets the value at the given index in the literal, static_cast to a double.
  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  double GetAsDouble(const Literal& literal,
                     absl::Span<const int64> input_index) {
    return static_cast<double>(literal.Get<NativeT>(input_index));
  }

  // Specialization for complex types. In this case it is not possible to
  // static_cast the value to a double, so just CHECK-fail. This method is not
  // used at run-time, but must be available at compile-time to keep the
  // compiler happy.
  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  double GetAsDouble(const Literal& literal,
                     absl::Span<const int64> input_index) {
    LOG(FATAL) << "Trying to get complex literal as double: "
               << literal.ToString();
  }

 public:
  explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {}

  // The following higher-order functions convert a function with ElementwiseT
  // to a function with ReturnT.
  std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
      const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
    return [&unary_op](ReturnT arg) {
      return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
    };
  }
  std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
          binary_op) {
    return [&binary_op](ReturnT arg1, ReturnT arg2) {
      return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
                                            static_cast<ElementwiseT>(arg2)));
    };
  }
  std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
                                       ElementwiseT)>& ternary_op) {
    return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
      return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
                                             static_cast<ElementwiseT>(arg2),
                                             static_cast<ElementwiseT>(arg3)));
    };
  }
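
  // Illustrative use (hypothetical values): with ReturnT = Eigen::half and
  // ElementwiseT = float, the returned wrapper casts each half operand up to
  // float, applies the op, and casts the result back down:
  //   std::function<float(float)> add_one = [](float x) { return x + 1.0f; };
  //   Eigen::half y = ConvertUnaryFunction(add_one)(Eigen::half(1.5f));
  //   // y == 2.5 in half precision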

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
                         HloOpcodeString(hlo_instruction->opcode()));
  }

  template <typename NativeT,
            typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
                nullptr>
  Status HandleAbs(HloInstruction* abs) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
                          return elem_operand;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
  Status HandleAbs(HloInstruction* abs) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
                          return std::abs(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleAbs(HloInstruction* abs) {
    const Literal& operand_literal =
        parent_->GetEvaluatedLiteralFor(abs->operand(0));
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[abs],
        (HloEvaluator::ElementWiseUnaryOpImpl<float, NativeT>(
            abs, [](NativeT elem_operand) { return std::abs(elem_operand); },
            operand_literal)));

    return Status::OK();
  }

  Status HandleAbs(HloInstruction* abs) override {
    // If the operand is complex, the return type of Abs is the corresponding
    // real type (e.g. F32 for C64), so ElementwiseT here would be that real
    // type. We therefore need to dispatch explicitly on the complex operand
    // type below.
    if (abs->operand(0)->shape().element_type() == C64) {
      return HandleAbs<complex64>(abs);
    } else if (abs->operand(0)->shape().element_type() == C128) {
      return HandleAbs<complex128>(abs);
    }
    return HandleAbs<ElementwiseT>(abs);
  }
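
  // e.g. (illustrative) Abs of a C64 scalar literal 3 + 4i evaluates to the
  // F32 scalar literal 5.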

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRound(HloInstruction* round) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[round],
        ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) {
          return std::round(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRound(HloInstruction* round) {
    return UnsupportedTypeError(round);
  }

  Status HandleRound(HloInstruction* round) override {
    return HandleRound<ReturnT>(round);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCeil(HloInstruction* ceil) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
                        ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) {
                          return std::ceil(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCeil(HloInstruction* ceil) {
    return UnsupportedTypeError(ceil);
  }

  Status HandleCeil(HloInstruction* ceil) override {
    return HandleCeil<ReturnT>(ceil);
  }

  Status HandleConvert(HloInstruction* convert) override {
    const HloInstruction* operand = convert->operand(0);
    TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
    TF_ASSIGN_OR_RETURN(Literal result,
                        parent_->GetEvaluatedLiteralFor(operand).Convert(
                            convert->shape().element_type()));
    parent_->evaluated_[convert] = std::move(result);
    return Status::OK();
  }

  Status HandleBitcastConvert(HloInstruction* convert) override {
    const HloInstruction* operand = convert->operand(0);
    TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
    TF_ASSIGN_OR_RETURN(Literal result,
                        parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
                            convert->shape().element_type()));

    parent_->evaluated_[convert] = std::move(result);
    return Status::OK();
  }
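
  // For intuition (illustrative values): Convert of an F32 literal 1.0f to
  // S32 yields 1, whereas BitcastConvert reinterprets the underlying bits and
  // yields 0x3F800000.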

  Status HandleExp(HloInstruction* exp) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
                        ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) {
                          return std::exp(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleExpm1(HloInstruction* expm1) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[expm1],
        ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) {
          return std::expm1(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleExpm1(HloInstruction* expm1) {
    return UnsupportedTypeError(expm1);
  }

  Status HandleExpm1(HloInstruction* expm1) override {
    return HandleExpm1<ReturnT>(expm1);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleFloor(HloInstruction* floor) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[floor],
        ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) {
          return std::floor(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleFloor(HloInstruction* floor) {
    return UnsupportedTypeError(floor);
  }

  Status HandleFloor(HloInstruction* floor) override {
    return HandleFloor<ReturnT>(floor);
  }

  Status HandleLog(HloInstruction* log) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
                        ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
                          return std::log(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleLog1p(HloInstruction* log1p) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[log1p],
        ElementWiseUnaryOp(log1p, [](ElementwiseT elem_operand) {
          return std::log1p(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleLog1p(HloInstruction* log1p) {
    return UnsupportedTypeError(log1p);
  }

  Status HandleLog1p(HloInstruction* log1p) override {
    return HandleLog1p<ReturnT>(log1p);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return ~elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return !elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return !elem_operand;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    return UnsupportedTypeError(not_);
  }

  Status HandleNot(HloInstruction* not_) override {
    return HandleNot<ElementwiseT>(not_);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_signed<NativeT>::value &&
                !std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
    // Negate in the unsigned domain so that negating the minimum signed value
    // (where -x would overflow) is well-defined 2's-complement wraparound.
    using type = typename std::make_unsigned<NativeT>::type;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[negate],
        ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) {
          return NativeT(-type(elem_operand));
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                !std::is_signed<NativeT>::value ||
                std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[negate],
        ElementWiseUnaryOp(
            negate, [](ElementwiseT elem_operand) { return -elem_operand; }));
    return Status::OK();
  }

  Status HandleNegate(HloInstruction* negate) override {
    return HandleNegate<ReturnT>(negate);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleSign(HloInstruction* sign) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
                        ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
                          // Standard signum: -1, 0, or +1.
                          return (ElementwiseT(0) < elem_operand) -
                                 (elem_operand < ElementwiseT(0));
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_same<NativeT, bfloat16>::value ||
                std::is_same<NativeT, Eigen::half>::value ||
                std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleSign(HloInstruction* sign) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
                        ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
                          return std::isnan(elem_operand)
                                     ? elem_operand
                                     : std::copysign(
                                           elem_operand != ElementwiseT(0),
                                           elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleSign(HloInstruction* sign) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
                        ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
                          auto abs_val = std::abs(elem_operand);
                          return 0 == abs_val ? ElementwiseT(0)
                                              : elem_operand / abs_val;
                        }));
    return Status::OK();
  }

  Status HandleSign(HloInstruction* sign) override {
    return HandleSign<ReturnT>(sign);
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAtan2(HloInstruction* atan2) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
                        ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
                                                      ElementwiseT rhs_elem) {
                          return std::atan2(lhs_elem, rhs_elem);
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<!std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAtan2(HloInstruction* atan2) {
    return UnsupportedTypeError(atan2);
  }

  Status HandleAtan2(HloInstruction* atan2) override {
    return HandleAtan2<ElementwiseT>(atan2);
  }

  Status HandleTanh(HloInstruction* tanh) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
                        ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
                          return std::tanh(elem_operand);
                        }));
    return Status::OK();
  }

  Status HandleMultiply(HloInstruction* multiply) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[multiply],
        ElementWiseBinaryOp(
            multiply, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
              return ElementwiseT(ToArithmeticSafeType(lhs_elem) *
                                  ToArithmeticSafeType(rhs_elem));
            }));
    return Status::OK();
  }

  Status HandleSubtract(HloInstruction* subtract) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[subtract],
        ElementWiseBinaryOp(
            subtract, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
              return ElementwiseT(ToArithmeticSafeType(lhs_elem) -
                                  ToArithmeticSafeType(rhs_elem));
            }));
    return Status::OK();
  }

  Status HandleAdd(HloInstruction* add) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[add],
                        ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem,
                                                    ElementwiseT rhs_elem) {
                          return ElementwiseT(ToArithmeticSafeType(lhs_elem) +
                                              ToArithmeticSafeType(rhs_elem));
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_floating_point<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleDivide(HloInstruction* divide) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
                        ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
                                                       ElementwiseT rhs_elem) {
                          return lhs_elem / rhs_elem;
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_signed<NativeT>::value &&
                                    std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleDivide(HloInstruction* divide) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[divide],
        ElementWiseBinaryOp(
            divide,
            [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) -> ElementwiseT {
              // Signed division by zero is UB in C++; define it to be -1.
              if (rhs_elem == 0) {
                return static_cast<ElementwiseT>(-1);
              }
              // Dividing the minimum value by -1 overflows; define it to
              // return the dividend unchanged.
              if (rhs_elem == -1 &&
                  lhs_elem == std::numeric_limits<ElementwiseT>::min()) {
                return lhs_elem;
              }
              return lhs_elem / rhs_elem;
            }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
                nullptr>
  Status HandleDivide(HloInstruction* divide) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
                        ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
                                                       ElementwiseT rhs_elem) {
                          // Unsigned division by zero yields the maximum
                          // representable value.
                          return rhs_elem == 0
                                     ? std::numeric_limits<ElementwiseT>::max()
                                     : (lhs_elem / rhs_elem);
                        }));
    return Status::OK();
  }

  Status HandleDivide(HloInstruction* divide) override {
    return HandleDivide<ElementwiseT>(divide);
  }
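
  // Integer division edge cases as defined above (illustrative int32/uint32
  // values):
  //   s32:  1 / 0 == -1,  INT_MIN / -1 == INT_MIN
  //   u32:  1 / 0 == UINT_MAX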

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleMaximum(HloInstruction* maximum) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[maximum],
        ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
          return std::max(lhs, rhs);
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleMaximum(HloInstruction* maximum) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[maximum],
        ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
          // NaN-propagating: if either operand is NaN, the result is NaN.
          return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs;
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMaximum(HloInstruction* maximum) {
    return UnsupportedTypeError(maximum);
  }

  Status HandleMaximum(HloInstruction* maximum) override {
    return HandleMaximum<ElementwiseT>(maximum);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleMinimum(HloInstruction* minimum) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
                        ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
                                                        ElementwiseT rhs_el) {
                          return std::min(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleMinimum(HloInstruction* minimum) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[minimum],
        ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
                                        ElementwiseT rhs_el) {
          // NaN-propagating: if either operand is NaN, the result is NaN.
          return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el;
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMinimum(HloInstruction* minimum) {
    return UnsupportedTypeError(minimum);
  }

  Status HandleMinimum(HloInstruction* minimum) override {
    return HandleMinimum<ElementwiseT>(minimum);
  }

  Status HandlePower(HloInstruction* power) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[power],
        ElementWiseBinaryOp(
            power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
              // Define pow(0, 0) = 1, which std::pow does not guarantee for
              // every element type (e.g. complex).
              return lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0)
                         ? static_cast<ElementwiseT>(1)
                         : std::pow(lhs_el, rhs_el);
            }));
    return Status::OK();
  }

  Status HandleSqrt(HloInstruction* sqrt) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sqrt],
                        ElementWiseUnaryOp(sqrt, [](ElementwiseT elem_operand) {
                          return std::sqrt(elem_operand);
                        }));
    return Status::OK();
  }

  Status HandleRsqrt(HloInstruction* rsqrt) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[rsqrt],
        ElementWiseUnaryOp(rsqrt, [](ElementwiseT elem_operand) {
          return static_cast<ElementwiseT>(1) / std::sqrt(elem_operand);
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
                        ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
                                                          ElementwiseT rhs_el) {
                          return std::fmod(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
                nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
                        ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
                                                          ElementwiseT rhs_el) {
                          // Define x % 0 = x to avoid UB.
                          return rhs_el == 0 ? lhs_el : (lhs_el % rhs_el);
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_signed<NativeT>::value &&
                                    std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[remainder],
        ElementWiseBinaryOp(
            remainder,
            [](ElementwiseT lhs_el, ElementwiseT rhs_el) -> ElementwiseT {
              // Define x % 0 = x to avoid UB.
              if (rhs_el == 0) {
                return lhs_el;
              }
              // INT_MIN % -1 overflows in C++; mathematically the remainder
              // is 0, so define it as such.
              if (rhs_el == -1 &&
                  lhs_el == std::numeric_limits<ElementwiseT>::min()) {
                return 0;
              }
              return lhs_el % rhs_el;
            }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    return UnsupportedTypeError(remainder);
  }

  Status HandleRemainder(HloInstruction* remainder) override {
    return HandleRemainder<ElementwiseT>(remainder);
  }
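
  // Remainder edge cases as defined above (illustrative int32 values):
  //   s32:  x % 0 == x,  INT_MIN % -1 == 0
  //   u32:  x % 0 == x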

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleAnd(HloInstruction* and_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[and_],
        ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el & rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAnd(HloInstruction* and_) {
    return UnsupportedTypeError(and_);
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleAnd(HloInstruction* and_) {
    return UnsupportedTypeError(and_);
  }

  Status HandleAnd(HloInstruction* and_) override {
    return HandleAnd<ElementwiseT>(and_);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleOr(HloInstruction* or_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[or_],
        ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el | rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleOr(HloInstruction* or_) {
    return UnsupportedTypeError(or_);
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleOr(HloInstruction* or_) {
    return UnsupportedTypeError(or_);
  }

  Status HandleOr(HloInstruction* or_) override {
    return HandleOr<ElementwiseT>(or_);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleXor(HloInstruction* xor_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[xor_],
        ElementWiseBinaryOp(xor_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el ^ rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleXor(HloInstruction* xor_) {
    return UnsupportedTypeError(xor_);
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleXor(HloInstruction* xor_) {
    return UnsupportedTypeError(xor_);
  }

  Status HandleXor(HloInstruction* xor_) override {
    return HandleXor<ElementwiseT>(xor_);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftLeft(HloInstruction* shl) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shl],
        ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
          return IsShiftOutOfBounds<NativeT>(rhs_elem) ? 0
                                                       : (lhs_elem << rhs_elem);
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftLeft(HloInstruction* shift) {
    return UnsupportedTypeError(shift);
  }

  Status HandleShiftLeft(HloInstruction* shl) override {
    return HandleShiftLeft<ElementwiseT>(shl);
  }
  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftRightArithmetic(HloInstruction* shr) {
    typedef typename std::make_signed<NativeT>::type SignedT;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shr],
        ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
          SignedT lhs_signed = static_cast<SignedT>(lhs_elem);
          if (IsShiftOutOfBounds<NativeT>(rhs_elem)) {
            // An out-of-bounds arithmetic shift fills with the sign bit.
            return lhs_signed < 0 ? static_cast<SignedT>(-1) : 0;
          } else {
            return lhs_signed >> rhs_elem;
          }
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftRightArithmetic(HloInstruction* shift) {
    return UnsupportedTypeError(shift);
  }

  Status HandleShiftRightArithmetic(HloInstruction* shra) override {
    return HandleShiftRightArithmetic<ElementwiseT>(shra);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftRightLogical(HloInstruction* shr) {
    typedef typename std::make_unsigned<NativeT>::type UnsignedT;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shr],
        ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
          // If shift amount is greater than the number of bits, then return 0.
          if (IsShiftOutOfBounds<NativeT>(rhs_elem)) {
            return static_cast<NativeT>(0);
          }
          return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
                                      rhs_elem);
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftRightLogical(HloInstruction* shift) {
    return UnsupportedTypeError(shift);
  }

  Status HandleShiftRightLogical(HloInstruction* shrl) override {
    return HandleShiftRightLogical<ElementwiseT>(shrl);
  }
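
  // Shift semantics illustrated (hypothetical int32 values):
  //   ShiftLeft(1, 40)            -> 0           (out-of-bounds shift)
  //   ShiftRightArithmetic(-8, 1) -> -4          (sign-extending)
  //   ShiftRightLogical(-8, 1)    -> 0x7FFFFFFC  (zero-filling)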

  // Special case for integral types, since MSVC's std::isnan cannot handle
  // integral arguments.
  template <typename NativeT,
            typename std::enable_if<!is_complex_t<NativeT>::value &&
                                    std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleClamp(HloInstruction* clamp) {
    std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
        clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
          return static_cast<ElementwiseT>(
              std::min(high, std::max(value, low)));
        };
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[clamp],
        ElementwiseTernaryOp(clamp,
                             std::move(ConvertTernaryFunction(clamp_op))));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!is_complex_t<NativeT>::value &&
                                    !std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleClamp(HloInstruction* clamp) {
    std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
        clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
          if (std::isnan(low) || std::isnan(high)) {
            return static_cast<ElementwiseT>(NAN);
          }
          return static_cast<ElementwiseT>(
              std::min<NativeT>(high, std::max<NativeT>(value, low)));
        };
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[clamp],
        ElementwiseTernaryOp(clamp,
                             std::move(ConvertTernaryFunction(clamp_op))));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleClamp(HloInstruction* clamp) {
    return UnsupportedTypeError(clamp);
  }

  Status HandleClamp(HloInstruction* clamp) override {
    return HandleClamp<ElementwiseT>(clamp);
  }

  Status HandleSelect(HloInstruction* select) override {
    CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape()));
    CHECK(select->shape().IsArray());
    std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
        [](bool pred, ReturnT on_true, ReturnT on_false) {
          if (pred) {
            return on_true;
          }
          return on_false;
        };
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
                        ElementwiseTernaryOp(select, std::move(select_op)));
    return Status::OK();
  }

  Status HandleReverse(HloInstruction* reverse) override {
    const auto result_shape = reverse->shape();
    const auto reverse_dimensions = reverse->dimensions();

    auto operand = reverse->operand(0);
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferReverseShape(operand->shape(),
                                                          reverse_dimensions));

    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    Literal result(result_shape);

    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
          std::vector<int64> from_index(out_index.begin(), out_index.end());
          for (const int64 dim : reverse_dimensions) {
            from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
          }
          return operand_literal.Get<ReturnT>(from_index);
        }));

    parent_->evaluated_[reverse] = std::move(result);
    return Status::OK();
  }
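
  // e.g. (illustrative) reversing dimension 0 of a 3-element literal maps
  // output index i to input index 2 - i, so {1, 2, 3} becomes {3, 2, 1}.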

  Status HandleConvolution(HloInstruction* conv) override {
    auto lhs = conv->operand(0);
    auto rhs = conv->operand(1);
    const auto& window = conv->window();
    const Shape& result_shape = conv->shape();
    const Shape& lhs_shape = lhs->shape();
    const Shape& rhs_shape = rhs->shape();

    TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
    TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
    CHECK(lhs_shape.IsArray());
    CHECK(rhs_shape.IsArray());
    CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
    CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));

    const auto& dnums = conv->convolution_dimension_numbers();
    const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
    CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size());
    CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
    CHECK_GE(num_spatial_dims, 0);
    CHECK_EQ(window.dimensions_size(), num_spatial_dims);

    const auto lhs_rank = lhs_shape.rank();
    const auto rhs_rank = rhs_shape.rank();

    CHECK_EQ(num_spatial_dims + 2, lhs_rank);
    CHECK_EQ(num_spatial_dims + 2, rhs_rank);

    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferConvolveShape(
                            lhs_shape, rhs_shape, conv->feature_group_count(),
                            conv->batch_group_count(), window, dnums));
    CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    std::vector<int64> window_dimension_sizes;
    for (auto i : dnums.kernel_spatial_dimensions()) {
      window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
    }

    const Shape& window_shape =
        ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);

    DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
    DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);

    auto lhs_literal_data = lhs_literal.data<ReturnT>();
    auto rhs_literal_data = rhs_literal.data<ReturnT>();

    const int64 feature_group_count = conv->feature_group_count();
    const int64 batch_group_count = conv->batch_group_count();

    auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
                 &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
                 rhs_literal_data, feature_group_count,
                 batch_group_count](const absl::Span<const int64> out_index) {
      // Dimension number applicable for input (lhs).
      const int64 input_batch_dim = dnums.input_batch_dimension();
      const int64 input_z_dim = dnums.input_feature_dimension();
      // Dimension number applicable for kernel (rhs).
      const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
      const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
      // Dimension number applicable for output.
      const int64 output_batch_dim = dnums.output_batch_dimension();
      const int64 output_z_dim = dnums.output_feature_dimension();

      const int64 input_z_size =
          ShapeUtil::GetDimension(lhs_shape, input_z_dim);

      const int64 input_batch_size =
          ShapeUtil::GetDimension(lhs_shape, input_batch_dim);

      const int64 batch_group_size = input_batch_size / batch_group_count;

      // The size of an input feature group.
      const int64 input_feature_group_size = input_z_size / feature_group_count;

      const int64 output_z_size =
          ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
      // The output feature dimension is a concatenation of convolution results
      // from the different groups.
      const int64 output_feature_group_size =
          output_z_size / feature_group_count;

      // Calculate the group index to which the current output index
      // belongs.
      const int64 feature_group_index =
          out_index[output_z_dim] / output_feature_group_size;

      const int64 batch_group_index = out_index[output_z_dim];

      ElementwiseT result_val = static_cast<ElementwiseT>(0);
      DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
                                        0);

      // Convolve input feature with kernel.
      // The mechanism indexes into the correct LHS (input) and RHS (kernel)
      // locations and accumulates multiplications for a given output index.
      do {
        // Find corresponding spatial dimension index for input (lhs).
        int64 lhs_linear_spatial_index = 0;
        int64 rhs_linear_spatial_index = 0;
        for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
          // Spatial dimension number for input (lhs) and output.
          const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
          const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki);

          // Calculate lhs (input) index without taking base dilation into
          // account.
          const auto& window_dim = window.dimensions(ki);
          const int64 undilated_index =
              out_index[output_spatial_dim] * window_dim.stride() -
              window_dim.padding_low() +
              rhs_spatial_index[ki] * window_dim.window_dilation();
          // Skip if the lhs (input) index is to be dilated.  As an
          // optimization, skip this mod if there's no dilation.
          if (window_dim.base_dilation() > 1 &&
              undilated_index % window_dim.base_dilation() != 0) {
            goto cnt;
          }

          // Calculate the actual lhs (input) index after dilation.  As an
          // optimization, skip this integer divide if there's no dilation.
          int64 lhs_spatial_index;
          if (window_dim.base_dilation() > 1) {
            lhs_spatial_index = undilated_index / window_dim.base_dilation();
          } else {
            lhs_spatial_index = undilated_index;
          }

          // Skip if input index is not in bounds.
          if (!(lhs_spatial_index >= 0 &&
                lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) {
            goto cnt;
          }

          lhs_linear_spatial_index +=
              lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
          rhs_linear_spatial_index +=
              (window_dim.window_reversal()
                   ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
                   : rhs_spatial_index[ki]) *
              rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
        }

        for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
          const int64 iz =
              feature_group_index * input_feature_group_size + rhs_iz;

          int64 lhs_linear_index = lhs_linear_spatial_index;

          lhs_linear_index += out_index[output_batch_dim] *
                              lhs_dim_multipliers[input_batch_dim];

          // We are scraping only the diagonal elements in the resultant
          // convolution output when batch_group_count is greater than 1,
          // where 1 is the default. No scraping is done in that case.
          // This approach works out automatically for 'groups' in batches
          // with group_size > 1, because we already descend down the batch
          // dimension for the 'output_batch_dim' above.
          lhs_linear_index +=
              ((batch_group_index * batch_group_size) % input_batch_size) *
              lhs_dim_multipliers[input_batch_dim];

          lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];

          int64 rhs_linear_index = rhs_linear_spatial_index;

          rhs_linear_index += out_index[output_z_dim] *
                              rhs_dim_multipliers[kernel_output_z_dim];
          rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];

          result_val +=
              static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
              static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
        }
      cnt : {}
      } while (IndexUtil::BumpIndices(window_shape,
                                      absl::MakeSpan(rhs_spatial_index)));

      return static_cast<ReturnT>(result_val);
    };

    Literal result(result_shape);
    TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));

    parent_->evaluated_[conv] = std::move(result);
    return Status::OK();
  }

  Status HandleDot(HloInstruction* dot) override {
    if (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() == 1 &&
        parent_->use_fast_path_) {
      return HandleDot<ReturnT>(dot);
    }
    return HandleDotSlowPath(dot);
  }
   1212   template <typename NativeT, typename std::enable_if<std::is_same<
   1213                                   NativeT, float>::value>::type* = nullptr>
   1214   Status HandleDot(HloInstruction* dot) {
   1215     const HloInstruction* lhs = dot->operand(0);
   1216     const HloInstruction* rhs = dot->operand(1);
   1217     CHECK(dot->shape().IsArray());
   1218     CHECK(lhs->shape().IsArray());
   1219     CHECK(rhs->shape().IsArray());
   1220 
   1221     const auto& dnums = dot->dot_dimension_numbers();
   1222 
   1223     const int64 lhs_rank = lhs->shape().rank();
   1224     const int64 rhs_rank = rhs->shape().rank();
   1225 
   1226     CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
   1227     CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
   1228 
    1229     // There must be exactly one contracting dimension for both lhs and rhs.
   1230     const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
   1231     const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
   1232     // Contracted dimension sizes must be the same.
   1233     CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
   1234              rhs->shape().dimensions(rhs_contracting_dimension))
   1235         << "lhs contracted dimension: "
   1236         << lhs->shape().dimensions(lhs_contracting_dimension)
   1237         << " rhs contracted dimension: "
   1238         << rhs->shape().dimensions(rhs_contracting_dimension);
   1239 
   1240     // The fast path is for a simple rank 2 dot with default layout operands.
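             // As an illustration (hypothetical shapes): lhs = f32[16,32]{1,0}
             // dot rhs = f32[32,8]{1,0} with lhs_contracting_dimension == 1 and
             // rhs_contracting_dimension == 0 takes this fast path and produces
             // f32[16,8]{1,0} via MatmulArray2D. Batch dimensions or
             // non-default layouts fall through to HandleDotSlowPath below;
             // non-float element types never reach this overload because of the
             // enable_if guard above.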
   1241     if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 &&
   1242         rhs_contracting_dimension == 0 &&
   1243         LayoutUtil::Equal(lhs->shape().layout(),
   1244                           LayoutUtil::GetDefaultLayoutForR2()) &&
   1245         LayoutUtil::Equal(rhs->shape().layout(),
   1246                           LayoutUtil::GetDefaultLayoutForR2()) &&
   1247         LayoutUtil::Equal(dot->shape().layout(),
   1248                           LayoutUtil::GetDefaultLayoutForR2())) {
   1249       const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
   1250       const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
   1251       const int64 contracted_dimension_size =
   1252           lhs->shape().dimensions(lhs_contracting_dimension);
   1253       Array2D<NativeT> lhs_array(lhs->shape().dimensions(0),
   1254                                  contracted_dimension_size);
   1255       lhs_array.SetValues(lhs_literal.data<NativeT>());
   1256       Array2D<NativeT> rhs_array(contracted_dimension_size,
   1257                                  rhs->shape().dimensions(1));
   1258       rhs_array.SetValues(rhs_literal.data<NativeT>());
   1259       std::unique_ptr<Array2D<NativeT>> result_array =
   1260           HloEvaluator::MatmulArray2D(lhs_array, rhs_array);
   1261       Literal result(dot->shape());
   1262       result.PopulateR2FromArray2D(*result_array);
   1263       parent_->evaluated_[dot] = std::move(result);
   1264       return Status::OK();
   1265     }
   1266     return HandleDotSlowPath(dot);
   1267   }
   1268 
   1269   template <typename NativeT, typename std::enable_if<!std::is_same<
   1270                                   NativeT, float>::value>::type* = nullptr>
   1271   Status HandleDot(HloInstruction* dot) {
   1272     return HandleDotSlowPath(dot);
   1273   }
   1274 
   1275   Status HandleDotSlowPath(HloInstruction* dot) {
   1276     auto lhs = dot->operand(0);
   1277     auto rhs = dot->operand(1);
   1278     CHECK(dot->shape().IsArray());
   1279     CHECK(lhs->shape().IsArray());
   1280     CHECK(rhs->shape().IsArray());
   1281 
   1282     const auto& dnums = dot->dot_dimension_numbers();
   1283 
   1284     const auto lhs_rank = lhs->shape().rank();
   1285     const auto rhs_rank = rhs->shape().rank();
   1286 
   1287     CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
   1288     CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));
   1289 
   1290     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
   1291     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
   1292 
   1293     CHECK_EQ(dnums.lhs_batch_dimensions_size(),
   1294              dnums.rhs_batch_dimensions_size());
   1295 
   1296     DimensionVector lhs_index(lhs_rank);
   1297     DimensionVector rhs_index(rhs_rank);
   1298 
   1299     // result_index_locations[i] contains one or two pointers to the locations
   1300     // in lhs_index or rhs_index where the i'th result index should go.
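             // E.g. (hypothetical shapes): for a batch matmul lhs[b,m,k] dot
             // rhs[b,k,n] with batch dimensions {0}/{0} and contracting
             // dimensions {2}/{1}, output index {b,m,n} is routed as
             // b -> lhs_index[0] and rhs_index[0], m -> lhs_index[1], and
             // n -> rhs_index[2].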
   1301     absl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
   1302         result_index_locations;
   1303     result_index_locations.reserve(
   1304         (lhs_rank - dnums.lhs_contracting_dimensions_size()) +
   1305         (rhs_rank - dnums.rhs_contracting_dimensions_size()));
   1306 
   1307     // The first components in the output shape are the LHS and RHS batch
   1308     // dimensions:
   1309     for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) {
   1310       result_index_locations.push_back(
   1311           {&lhs_index[dnums.lhs_batch_dimensions(i)],
   1312            &rhs_index[dnums.rhs_batch_dimensions(i)]});
   1313     }
   1314 
   1315     // Then we have the LHS and RHS non-contracting dimensions, if any:
   1316     for (int64 i = 0; i < lhs_rank; i++) {
   1317       if (!absl::c_linear_search(dnums.lhs_contracting_dimensions(), i) &&
   1318           !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) {
   1319         result_index_locations.push_back({&lhs_index[i], nullptr});
   1320       }
   1321     }
   1322     for (int64 i = 0; i < rhs_rank; i++) {
   1323       if (!absl::c_linear_search(dnums.rhs_contracting_dimensions(), i) &&
   1324           !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) {
   1325         result_index_locations.push_back({&rhs_index[i], nullptr});
   1326       }
   1327     }
   1328 
   1329     absl::InlinedVector<int64, kInlineRank> accumulate_index_sizes;
   1330     accumulate_index_sizes.reserve(dnums.lhs_contracting_dimensions_size());
   1331     absl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
   1332         accumulate_index_locations;
   1333     accumulate_index_locations.reserve(dnums.lhs_contracting_dimensions_size());
   1334     for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
   1335       const int64 lhs_dnum = dnums.lhs_contracting_dimensions(i);
   1336       const int64 rhs_dnum = dnums.rhs_contracting_dimensions(i);
   1337       accumulate_index_locations.push_back(
   1338           {&lhs_index[lhs_dnum], &rhs_index[rhs_dnum]});
   1339       const int64 dim_size = lhs->shape().dimensions(lhs_dnum);
   1340       accumulate_index_sizes.push_back(dim_size);
   1341     }
   1342     const int64 total_contraction_size = Product(accumulate_index_sizes);
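             // Worked example (hypothetical shapes): for lhs f32[2,3] dot
             // rhs f32[3,4] with lhs_contracting_dimensions = {1} and
             // rhs_contracting_dimensions = {0}, accumulate_index_sizes is {3}
             // and total_contraction_size is 3, so each output element sums
             // three products; with contracting sizes {3,4} the odometer loop
             // below would visit all 12 index combinations.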
   1343     Literal result(dot->shape());
   1344     TF_RETURN_IF_ERROR(
   1345         result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
   1346           ElementwiseT result_val = static_cast<ElementwiseT>(0);
   1347 
   1348           for (int64 i = 0; i < result_index.size(); i++) {
   1349             *result_index_locations[i].first = result_index[i];
   1350             if (result_index_locations[i].second) {
   1351               *result_index_locations[i].second = result_index[i];
   1352             }
   1353           }
   1354 
    1355           // Accumulates the products along the contracted dimensions.
   1356           absl::InlinedVector<int64, kInlineRank> accumulate_index(
   1357               accumulate_index_sizes.size(), 0);
   1358           for (int64 k = 0; k < total_contraction_size; k++) {
   1359             for (int64 i = 0; i < accumulate_index_sizes.size(); ++i) {
   1360               *(accumulate_index_locations[i].first) = accumulate_index[i];
   1361               *(accumulate_index_locations[i].second) = accumulate_index[i];
   1362             }
   1363 
   1364             result_val +=
   1365                 static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
   1366                 static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
   1367 
    1368             // If there are no contracting dimensions, accumulate_index_sizes
    1369             // is empty; do not try to count down from -1 to 0, since that
    1370             // would be an infinite loop.
   1371             if (!accumulate_index_sizes.empty()) {
   1372               for (int64 i = accumulate_index_sizes.size() - 1; i >= 0; --i) {
   1373                 int64 value = ++accumulate_index[i];
   1374                 if (value != accumulate_index_sizes[i]) {
   1375                   break;
   1376                 }
   1377                 accumulate_index[i] = 0;
   1378               }
   1379             }
   1380           }
   1381 
   1382           return static_cast<ReturnT>(result_val);
   1383         }));
   1384 
   1385     parent_->evaluated_[dot] = std::move(result);
   1386     return Status::OK();
   1387   }
   1388 
   1389   Status HandlePad(HloInstruction* pad) override {
   1390     CHECK(pad->operand(0)->shape().IsArray());
   1391     // Padding value must be scalar.
   1392     CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
   1393     CHECK_EQ(pad->operand(0)->shape().rank(),
   1394              pad->padding_config().dimensions_size());
   1395 
   1396     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
   1397                         ShapeInference::InferPadShape(
   1398                             /*operand_shape=*/pad->operand(0)->shape(),
   1399                             /*padding_value_shape=*/pad->operand(1)->shape(),
   1400                             /*padding_config=*/pad->padding_config()));
   1401     CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
   1402         << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
   1403         << " but is inferred to be: "
   1404         << ShapeUtil::HumanString(inferred_return_shape);
   1405 
   1406     // Create new HLO of padded shape with padding value.
   1407     ReturnT scalar =
   1408         parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
   1409     Literal result(pad->shape());
   1410     TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
   1411         [&scalar](absl::Span<const int64> multi_index) { return scalar; }));
   1412 
   1413     const Literal& evaluated_operand =
   1414         parent_->GetEvaluatedLiteralFor(pad->operand(0));
   1415 
   1416     std::vector<int64> input_index(evaluated_operand.shape().rank(), 0);
   1417     std::vector<int64> target_index(result.shape().rank(), 0);
   1418 
    1419     // Loop through each element of the operand, assigning it to the
    1420     // corresponding index of the resulting padded literal.
   1421     const PaddingConfig& pad_config = pad->padding_config();
   1422 
   1423     auto func = [&](absl::Span<const int64> input_index) {
   1424       for (auto i = 0; i < input_index.size(); ++i) {
    1425         // Interior padding occurs logically before edge padding, so in the
    1426         // case of negative edge padding, elements are removed from the
    1427         // interior-padded operand.
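                 // Worked example (hypothetical padding config): with
                 // edge_padding_low() == 2 and interior_padding() == 1, input
                 // index i maps to target index 2 + 2 * i, so an input of
                 // length 3 lands at output positions 2, 4, and 6.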
   1428         target_index[i] =
   1429             pad_config.dimensions(i).edge_padding_low() +
   1430             input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);
   1431 
    1432         // Account for negative low and high padding: skip the assignment if
    1433         // any target index is out of range.
   1434         if (!(target_index[i] >= 0 &&
   1435               target_index[i] < pad->shape().dimensions(i))) {
   1436           return true;
   1437         }
   1438       }
   1439       result.Set<ReturnT>(target_index,
   1440                           evaluated_operand.Get<ReturnT>(input_index));
   1441       return true;
   1442     };
   1443 
   1444     std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(),
   1445                                  0);
   1446     std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1);
   1447 
   1448     ShapeUtil::ForEachIndex(
   1449         evaluated_operand.shape(), zero_base,
   1450         AsInt64Slice(evaluated_operand.shape().dimensions()), step, func);
   1451 
   1452     parent_->evaluated_[pad] = std::move(result);
   1453     return Status::OK();
   1454   }
   1455 
   1456   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
   1457     auto operand = dynamic_slice->operand(0);
   1458     auto start_indices = dynamic_slice->operand(1);
   1459     auto result_shape = dynamic_slice->shape();
   1460     TF_ASSIGN_OR_RETURN(
   1461         auto inferred_return_shape,
   1462         ShapeInference::InferDynamicSliceShape(
   1463             operand->shape(),
   1464             Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
   1465             dynamic_slice->dynamic_slice_sizes()));
   1466     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
   1467         << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
   1468         << " but is inferred to be: "
   1469         << ShapeUtil::HumanString(inferred_return_shape);
   1470     TF_RET_CHECK(
   1471         primitive_util::IsIntegralType(start_indices->shape().element_type()));
   1472 
   1473     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
   1474 
   1475     switch (start_indices->shape().element_type()) {
   1476       case S32: {
   1477         TF_ASSIGN_OR_RETURN(
   1478             parent_->evaluated_[dynamic_slice],
   1479             DynamicSlice<int32>(
   1480                 operand_literal,
   1481                 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1),
   1482                 result_shape));
   1483       } break;
   1484       case S64: {
   1485         TF_ASSIGN_OR_RETURN(
   1486             parent_->evaluated_[dynamic_slice],
   1487             DynamicSlice<int64>(
   1488                 operand_literal,
   1489                 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1),
   1490                 result_shape));
   1491       } break;
   1492       case U32: {
   1493         TF_ASSIGN_OR_RETURN(
   1494             parent_->evaluated_[dynamic_slice],
   1495             DynamicSlice<uint32>(
   1496                 operand_literal,
   1497                 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1),
   1498                 result_shape));
   1499       } break;
   1500       case U64: {
   1501         TF_ASSIGN_OR_RETURN(
   1502             parent_->evaluated_[dynamic_slice],
   1503             DynamicSlice<uint64>(
   1504                 operand_literal,
   1505                 absl::MakeConstSpan(dynamic_slice->operands()).subspan(1),
   1506                 result_shape));
   1507       } break;
   1508       default:
   1509         LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
   1510                       "start_indices: "
   1511                    << PrimitiveType_Name(start_indices->shape().element_type());
   1512     }
   1513 
   1514     return Status::OK();
   1515   }
   1516 
   1517   Status HandleDynamicUpdateSlice(
   1518       HloInstruction* dynamic_update_slice) override {
   1519     auto operand = dynamic_update_slice->operand(0);
   1520     auto update = dynamic_update_slice->operand(1);
   1521     auto start_indices = dynamic_update_slice->operand(2);
   1522     auto result_shape = dynamic_update_slice->shape();
   1523     TF_ASSIGN_OR_RETURN(
   1524         auto inferred_return_shape,
   1525         ShapeInference::InferDynamicUpdateSliceShape(
   1526             operand->shape(), update->shape(),
   1527             Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
   1528                 ->index_shapes()));
   1529     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
   1530         << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
   1531         << " but is inferred to be: "
   1532         << ShapeUtil::HumanString(inferred_return_shape);
   1533     TF_RET_CHECK(
   1534         primitive_util::IsIntegralType(start_indices->shape().element_type()));
   1535     TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
   1536 
   1537     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
   1538     const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
   1539 
   1540     switch (start_indices->shape().element_type()) {
   1541       case S32: {
   1542         TF_ASSIGN_OR_RETURN(
   1543             parent_->evaluated_[dynamic_update_slice],
   1544             DynamicUpdateSlice<int32>(
   1545                 operand_literal, update_literal,
   1546                 absl::MakeConstSpan(dynamic_update_slice->operands())
   1547                     .subspan(2)));
   1548       } break;
   1549       case S64: {
   1550         TF_ASSIGN_OR_RETURN(
   1551             parent_->evaluated_[dynamic_update_slice],
   1552             DynamicUpdateSlice<int64>(
   1553                 operand_literal, update_literal,
   1554                 absl::MakeConstSpan(dynamic_update_slice->operands())
   1555                     .subspan(2)));
   1556       } break;
   1557       case U32: {
   1558         TF_ASSIGN_OR_RETURN(
   1559             parent_->evaluated_[dynamic_update_slice],
   1560             DynamicUpdateSlice<uint32>(
   1561                 operand_literal, update_literal,
   1562                 absl::MakeConstSpan(dynamic_update_slice->operands())
   1563                     .subspan(2)));
   1564       } break;
   1565       case U64: {
   1566         TF_ASSIGN_OR_RETURN(
   1567             parent_->evaluated_[dynamic_update_slice],
   1568             DynamicUpdateSlice<uint64>(
   1569                 operand_literal, update_literal,
   1570                 absl::MakeConstSpan(dynamic_update_slice->operands())
   1571                     .subspan(2)));
   1572       } break;
   1573       default:
   1574         LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
   1575                       "start_indices: "
   1576                    << PrimitiveType_Name(start_indices->shape().element_type());
   1577     }
   1578 
   1579     return Status::OK();
   1580   }
   1581 
   1582   template <typename NativeT>
   1583   StatusOr<Literal> MapImpl(HloInstruction* map) {
   1584     auto operands = map->operands();
   1585     HloComputation* computation = map->to_apply();
   1586 
   1587     Literal result(map->shape());
   1588 
   1589     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
   1590     TF_RETURN_IF_ERROR(
   1591         result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
   1592           std::vector<Literal> arg_literals;
   1593           arg_literals.reserve(operands.size());
   1594 
   1595           // Construct scalar literal parameters to be passed to the map
   1596           // computation.
   1597           for (auto operand : operands) {
   1598             const Literal& arg_literal =
   1599                 parent_->GetEvaluatedLiteralFor(operand);
   1600 
   1601             auto curr_val = arg_literal.Get<NativeT>(multi_index);
   1602             auto curr_val_literal = LiteralUtil::CreateR0<NativeT>(curr_val);
   1603 
   1604             arg_literals.push_back(std::move(curr_val_literal));
   1605           }
   1606 
   1607           Literal computed_result =
   1608               embedded_evaluator.Evaluate(*computation, arg_literals)
   1609                   .ConsumeValueOrDie();
    1610           // Clear visit states so that we can use the evaluator again on
    1611           // the same computation.
   1612           embedded_evaluator.ResetVisitStates();
   1613 
   1614           return computed_result.Get<ReturnT>({});
   1615         }));
   1616     return std::move(result);
   1617   }
   1618 
   1619   Status HandleMap(HloInstruction* map) override {
   1620     switch (map->operand(0)->shape().element_type()) {
   1621       case PRED: {
   1622         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map));
   1623         break;
   1624       }
   1625       case U8: {
   1626         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
   1627         break;
   1628       }
   1629       case U32: {
   1630         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
   1631         break;
   1632       }
   1633       case U64: {
   1634         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map));
   1635         break;
   1636       }
   1637       case S8: {
   1638         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
   1639         break;
   1640       }
   1641       case S32: {
   1642         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
   1643         break;
   1644       }
   1645       case S64: {
   1646         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
   1647         break;
   1648       }
   1649       case F16: {
   1650         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
   1651                             MapImpl<Eigen::half>(map));
   1652         break;
   1653       }
   1654       case F32: {
   1655         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
   1656         break;
   1657       }
   1658       case F64: {
   1659         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map));
   1660         break;
   1661       }
   1662       case C64: {
   1663         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
   1664         break;
   1665       }
   1666       case C128: {
   1667         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex128>(map));
   1668         break;
   1669       }
   1670       default:
   1671         LOG(FATAL) << "HandleMap: unhandled primitive type for "
   1672                       "input operand: "
   1673                    << PrimitiveType_Name(
   1674                           map->operand(0)->shape().element_type());
   1675     }
   1676 
   1677     return Status::OK();
   1678   }
   1679 
   1680   Status HandleSort(HloInstruction* sort) override {
   1681     return UnsupportedTypeError(sort);
   1682   }
   1683 
   1684   Status HandleReduce(HloInstruction* hlo) override {
   1685     HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
   1686     int64 num_args = reduce->inputs().size();
   1687     bool has_tuple_output = reduce->shape().IsTuple();
   1688     absl::Span<const int64> dimensions(reduce->dimensions());
   1689     HloComputation* function = reduce->to_apply();
   1690 
   1691     absl::InlinedVector<const Shape*, 1> operand_shapes;
   1692     for (const HloInstruction* operand : reduce->operands()) {
   1693       operand_shapes.push_back(&operand->shape());
   1694     }
   1695     TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
   1696                         ShapeInference::InferReduceShape(
   1697                             operand_shapes,
   1698                             /*dimensions_to_reduce=*/dimensions,
   1699                             /*to_apply=*/function->ComputeProgramShape()));
   1700     TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
   1701         << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
   1702         << " but is inferred to be: "
   1703         << ShapeUtil::HumanString(inferred_return_shape);
   1704 
   1705     absl::InlinedVector<const Literal*, 1> arg_literals(num_args);
   1706     absl::InlinedVector<const Literal*, 1> init_literals(num_args);
   1707     for (int64 i = 0; i < num_args; ++i) {
   1708       arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]);
   1709       VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString();
   1710       init_literals[i] =
   1711           &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]);
   1712       VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString();
   1713       TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape()));
   1714     }
   1715 
   1716     // All args and results have the same dimensions, so pick an arbitrary one.
   1717     const Shape& arg_shape = arg_literals[0]->shape();
   1718     const Shape& result_shape = reduce->shape().IsTuple()
   1719                                     ? reduce->shape().tuple_shapes(0)
   1720                                     : reduce->shape();
   1721     const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions());
   1722     std::vector<int64> arg_dim_steps(arg_dimensions.size());
   1723     std::vector<int64> arg_dim_counts(arg_dimensions.size());
   1724     for (const int64 dim : dimensions) {
   1725       arg_dim_steps[dim] = 1;
   1726       arg_dim_counts[dim] = arg_dimensions[dim];
   1727     }
   1728 
   1729     // Map each dimension in the result to a dimension in arg that isn't
   1730     // being reduced.
   1731     std::vector<int64> result_to_arg_index;
   1732     for (int64 i = 0; i < arg_dimensions.size(); ++i) {
   1733       if (arg_dim_steps[i] == 0) {
   1734         result_to_arg_index.push_back(i);
   1735       }
   1736     }
   1737 
   1738     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
   1739     absl::InlinedVector<Literal, 1> results(num_args);
   1740     for (int64 i = 0; i < num_args; ++i) {
   1741       results[i] = Literal(result_shape);
   1742     }
   1743 
   1744     Status eval_status;
   1745     // For each resulting dimension, calculate and assign computed values.
   1746     // This is really wasteful when num_args > 1, since we re-run the
    1747     // reduction num_args times. The alternative is to teach Populate() about
   1748     // tuples, which we should probably do.
   1749     absl::InlinedVector<ReturnT, 1> init_scalars(num_args);
   1750     for (int i = 0; i < num_args; ++i) {
   1751       init_scalars[i] = init_literals[i]->Get<ReturnT>({});
   1752     }
   1753 
   1754     for (int64 input = 0; input < num_args; ++input) {
   1755       TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
   1756           [&](absl::Span<const int64> multi_index) {
   1757             if (!eval_status.ok()) {
   1758               return init_scalars[input];
   1759             }
   1760             absl::InlinedVector<ReturnT, 1> result_values(init_scalars.begin(),
   1761                                                           init_scalars.end());
   1762             std::vector<int64> base(arg_dimensions.size());
   1763             for (int64 i = 0; i < multi_index.size(); ++i) {
   1764               base[result_to_arg_index[i]] = multi_index[i];
   1765             }
   1766 
   1767             // When the reduction is addition of floats, accumulate in a double
   1768             // for better precision. Also, avoid creating Literals for the
   1769             // intermediate results; it's much faster.
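                         // For instance, accumulating 2^24 + 1 ones in an f32
                         // stalls at 16777216.0f (the next integer is not
                         // representable in f32); a double accumulator
                         // preserves these increments.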
   1770             if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) &&
   1771                 IsScalarAdd(function)) {
   1772               CHECK_EQ(num_args, 1);
   1773               double computed_result = 0;
   1774               auto func = [&](absl::Span<const int64> input_index) {
   1775                 computed_result +=
   1776                     GetAsDouble<ReturnT>(*arg_literals[0], input_index);
   1777                 return true;
   1778               };
   1779               ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base,
   1780                                       arg_dim_counts, arg_dim_steps, func);
   1781               return static_cast<ReturnT>(computed_result);
   1782             }
   1783             auto func =
   1784                 [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
   1785               absl::InlinedVector<ReturnT, 1> arg_values(num_args);
   1786               for (int64 i = 0; i < num_args; ++i) {
   1787                 arg_values[i] = arg_literals[i]->Get<ReturnT>(input_index);
   1788               }
   1789 
   1790               // Evaluate computation with specified literal operands.
   1791               absl::InlinedVector<Literal, 1> embedded_operands;
   1792               for (ReturnT value : result_values) {
   1793                 embedded_operands.push_back(
   1794                     LiteralUtil::CreateR0<ReturnT>(value));
   1795               }
   1796               for (ReturnT value : arg_values) {
   1797                 embedded_operands.push_back(
   1798                     LiteralUtil::CreateR0<ReturnT>(value));
   1799               }
   1800               absl::InlinedVector<Literal*, 1> embedded_operands_ptrs(
   1801                   embedded_operands.size());
   1802               std::transform(embedded_operands.begin(), embedded_operands.end(),
   1803                              embedded_operands_ptrs.begin(),
   1804                              [](Literal& literal) { return &literal; });
   1805 
   1806               TF_ASSIGN_OR_RETURN(Literal computed_result,
   1807                                   embedded_evaluator.Evaluate(
   1808                                       *function, embedded_operands_ptrs));
   1809               // Clear visit states so that we can use the evaluator again on
   1810               // the same computation.
   1811               embedded_evaluator.ResetVisitStates();
    1812               // Assign the computed result to result_values.
   1813               if (!has_tuple_output) {
   1814                 result_values[0] = computed_result.Get<ReturnT>({});
   1815               } else {
   1816                 for (int64 i = 0; i < num_args; ++i) {
   1817                   result_values[i] = computed_result.Get<ReturnT>(
   1818                       /*multi_index=*/{}, /*shape_index=*/{i});
   1819                 }
   1820               }
   1821               return true;
   1822             };
   1823             // Computes one element of the result, reducing all dimensions that
   1824             // contribute to that element.
   1825             eval_status = ShapeUtil::ForEachIndexWithStatus(
   1826                 arg_shape, base, arg_dim_counts, arg_dim_steps, func);
   1827             return result_values[input];
   1828           }));
   1829     }
   1830     if (!has_tuple_output) {
   1831       parent_->evaluated_[reduce] = std::move(results[0]);
   1832     } else {
   1833       Literal tuple_result(reduce->shape());
   1834       for (int64 i = 0; i < num_args; ++i) {
   1835         TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
   1836       }
   1837       parent_->evaluated_[reduce] = std::move(tuple_result);
   1838     }
   1839     return eval_status;
   1840   }
   1841 
   1842   bool IsScalarAdd(HloComputation* computation) {
   1843     HloInstruction* instruction = computation->root_instruction();
   1844     if (instruction->opcode() == HloOpcode::kAdd &&
   1845         computation->num_parameters() == 2) {
   1846       const HloInstruction* lhs = instruction->operand(0);
   1847       const HloInstruction* rhs = instruction->operand(1);
   1848       return lhs->opcode() == HloOpcode::kParameter &&
   1849              ShapeUtil::IsScalar(lhs->shape()) &&
   1850              rhs->opcode() == HloOpcode::kParameter &&
   1851              ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs;
   1852     }
   1853     return false;
   1854   }
   1855 
   1856   Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
   1857     auto operand = select_and_scatter->operand(0);
   1858     auto source = select_and_scatter->operand(1);
   1859     const Window& window = select_and_scatter->window();
   1860 
   1861     const Literal& init_literal =
   1862         parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
   1863     TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
   1864     auto init_scalar = init_literal.Get<ReturnT>({});
   1865 
   1866     Literal result(select_and_scatter->shape());
   1867 
   1868     // Initialize result array with the init value.
   1869     TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
   1870         [&](absl::Span<const int64> output_index) { return init_scalar; }));
   1871 
   1872     std::vector<int64> window_dimension_sizes;
   1873     for (const auto& window_dimension : window.dimensions()) {
   1874       window_dimension_sizes.push_back(window_dimension.size());
   1875     }
   1876     const Shape window_shape = ShapeUtil::MakeShape(
   1877         operand->shape().element_type(), window_dimension_sizes);
   1878 
   1879     HloComputation* select = select_and_scatter->select();
   1880     HloComputation* scatter = select_and_scatter->scatter();
   1881 
   1882     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
   1883     const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
   1884 
   1885     int64 rank = operand_literal.shape().rank();
   1886 
   1887     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
   1888     DimensionVector source_index(rank, 0);
   1889 
   1890     // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid
   1891     // dynamic memory allocations.
   1892     auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
   1893     auto selected_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
   1894     auto source_literal_scatter = LiteralUtil::CreateR0<ReturnT>(ReturnT());
   1895     auto scattered_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
   1896     do {
   1897       // For each element in `source`, we place a window in `operand`. For each
   1898       // window placement, we iterate inside the window twice:
   1899       //
    1900       // 1. Find the selected index by applying the `select` function to all
    1901       // elements. E.g., if the `select` function is GreaterEqual, the first
    1902       // iteration through the window finds the biggest value and returns its
    1903       // index.
    1904       //
    1905       // 2. Using the selected index, scatter the value from `source` to the
    1906       // result. We do this by iterating through the window and comparing each
    1907       // index with the selected index.
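               // Sketch of one window placement (hypothetical data): for a 1-D
               // size-3 window over operand {4, 9, 7} with a GreaterEqual
               // `select`, pass 1 selects index 1 (value 9); pass 2 then
               // scatters the matching `source` element into result[1] via
               // `scatter`.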
   1908       absl::optional<ReturnT> selected_val;
   1909       absl::optional<std::vector<int64>> selected_index;
   1910 
   1911       IterateThroughWindow(
   1912           window_shape, window, operand_literal.shape(), source_index,
   1913           [&](const std::vector<int64>& operand_index) {
   1914             auto curr_val = operand_literal.Get<ReturnT>(operand_index);
   1915             if (!selected_val) {
   1916               selected_val = curr_val;
   1917               selected_index = operand_index;
   1918             }
   1919             curr_val_literal.Set({}, curr_val);
   1920             selected_val_literal.Set({}, *selected_val);
   1921             Literal computed_result =
   1922                 embedded_evaluator
   1923                     .Evaluate(*select,
   1924                               {&selected_val_literal, &curr_val_literal})
   1925                     .ConsumeValueOrDie();
   1926             bool selected = !computed_result.Get<bool>({});
   1927             if (selected) {
   1928               selected_val = curr_val;
   1929               selected_index = operand_index;
   1930             }
   1931             embedded_evaluator.ResetVisitStates();
   1932           });
   1933 
   1934       IterateThroughWindow(
   1935           window_shape, window, operand_literal.shape(), source_index,
   1936           [&](const std::vector<int64>& operand_index) {
   1937             if (std::equal(operand_index.begin(), operand_index.end(),
   1938                            selected_index->begin())) {
   1939               auto source = source_literal.Get<ReturnT>(source_index);
   1940               auto scattered = result.Get<ReturnT>(operand_index);
   1941               source_literal_scatter.Set({}, source);
   1942               scattered_literal.Set({}, scattered);
   1943               Literal computed_result =
   1944                   embedded_evaluator
   1945                       .Evaluate(*scatter,
   1946                                 {&source_literal_scatter, &scattered_literal})
   1947                       .ConsumeValueOrDie();
   1948               result.Set(operand_index, computed_result.Get<ReturnT>({}));
    1949               // Clear visit states so that we can use the evaluator again
    1950               // on the same computation.
   1951               embedded_evaluator.ResetVisitStates();
   1952             }
   1953           });
   1954     } while (
   1955         IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index)));
   1956 
   1957     parent_->evaluated_[select_and_scatter] = std::move(result);
   1958     return Status::OK();
   1959   }
   1960 
   1961   Status HandleReduceWindow(HloInstruction* reduce_window) override {
   1962     auto operand = reduce_window->operand(0);
   1963     const Window& window = reduce_window->window();
   1964     HloComputation* function = reduce_window->to_apply();
   1965     TF_ASSIGN_OR_RETURN(
   1966         auto inferred_return_shape,
   1967         ShapeInference::InferReduceWindowShape(
   1968             /*operand_shape=*/reduce_window->operand(0)->shape(),
   1969             /*init_value=*/reduce_window->operand(1)->shape(), window,
   1970             /*to_apply_shape=*/function->ComputeProgramShape()));
   1971     TF_RET_CHECK(
   1972         ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
   1973         << "return shape is set to: "
   1974         << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
   1975         << " but is inferred to be: "
   1976         << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
   1977 
   1978     const Literal& operand_literal =
   1979         parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
   1980     VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
   1981     const Literal& init_literal =
   1982         parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
   1983     VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
   1984     TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
   1985     auto init_scalar = init_literal.Get<ReturnT>({});
   1986 
   1987     // Creates a Shape object from window, for iteration below.
   1988     std::vector<int64> window_dimension_sizes;
   1989     for (const auto& window_dimension : window.dimensions()) {
   1990       window_dimension_sizes.push_back(window_dimension.size());
   1991     }
   1992     const Shape window_shape = ShapeUtil::MakeShape(
   1993         operand->shape().element_type(), window_dimension_sizes);
   1994 
   1995     DimensionVector window_index(window.dimensions_size());
   1996     DimensionVector operand_index(operand_literal.shape().rank());
   1997 
   1998     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
   1999     Literal result(reduce_window->shape());
   2000     // For each resulting dimension, calculate and assign computed value.
   2001     TF_RETURN_IF_ERROR(
   2002         result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
   2003           ReturnT result_val = init_scalar;
   2004 
   2005           std::fill(window_index.begin(), window_index.end(), 0);
   2006           std::fill(operand_index.begin(), operand_index.end(), 0);
   2007 
   2008           IterateThroughWindow(
   2009               window_shape, window, operand_literal.shape(), output_index,
   2010               [&](const std::vector<int64>& operand_index) {
   2011                 auto curr_val = operand_literal.Get<ReturnT>(operand_index);
   2012 
   2013                 // Evaluate computation with specified literal operands.
   2014                 const auto curr_val_literal =
   2015                     LiteralUtil::CreateR0<ReturnT>(curr_val);
   2016                 const auto result_val_literal =
   2017                     LiteralUtil::CreateR0<ReturnT>(result_val);
   2018                 Literal computed_result =
   2019                     embedded_evaluator
   2020                         .Evaluate(*function,
   2021                                   {&result_val_literal, &curr_val_literal})
   2022                         .ConsumeValueOrDie();
   2023 
    2024                 // Clear visit states so that we can use the evaluator again
    2025                 // on the same computation.
   2026                 embedded_evaluator.ResetVisitStates();
   2027 
   2028                 result_val = computed_result.Get<ReturnT>({});
   2029               });
   2030 
   2031           return result_val;
   2032         }));
   2033 
   2034     parent_->evaluated_[reduce_window] = std::move(result);
   2035     return Status::OK();
   2036   }
   2037 
   2038   // Reshapes the scatter indices input to have a trailing degenerate `1`
   2039   // dimension if necessary.  Hands over the ownership of the newly created
   2040   // literal (if there is one) to `reshaped_indices`.
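           // For example (hypothetical shapes): with index_vector_dim == 1,
           // scatter indices of shape s64[5] are reshaped to s64[5,1] so that
           // each row holds a one-element index vector, while indices of shape
           // s64[5,2] are returned unchanged.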
   2041   StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
   2042       int64 index_vector_dim, const Literal& indices,
   2043       Literal* reshaped_indices) {
   2044     if (indices.shape().dimensions_size() != index_vector_dim) {
   2045       return std::cref(indices);
   2046     }
   2047 
   2048     std::vector<int64> new_shape(indices.shape().dimensions().begin(),
   2049                                  indices.shape().dimensions().end());
   2050     new_shape.push_back(1);
   2051     TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
   2052     return std::cref(*reshaped_indices);
   2053   }
   2054 
    2055   // Returns a ShapeUtil::IndexIterationSpace that iterates over the update
   2056   // scatter dimensions while keeping the rest of the update dimensions clamped
   2057   // to 0.
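           // For example (hypothetical shapes): for updates of shape [7,3,3]
           // with update_window_dims = {1,2}, this returns base {0,0,0},
           // count {7,1,1}, and unit increments, i.e. it walks only scatter
           // dimension 0.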
   2058   ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices(
   2059       const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
   2060     int64 updates_rank = updates_shape.dimensions_size();
   2061     std::vector<int64> index_base(updates_rank, 0);
   2062     std::vector<int64> index_count(updates_rank, 1);
   2063     for (int64 i = 0; i < updates_rank; i++) {
   2064       bool is_update_scatter_dim =
   2065           !absl::c_binary_search(dim_numbers.update_window_dims(), i);
   2066       if (is_update_scatter_dim) {
   2067         index_count[i] = updates_shape.dimensions(i);
   2068       }
   2069     }
   2070     return {std::move(index_base), std::move(index_count),
   2071             std::vector<int64>(updates_rank, 1)};
   2072   }
   2073 
    2074   // Returns a ShapeUtil::IndexIterationSpace that iterates over the update
   2075   // window dimensions while keeping the rest of the update dimensions clamped
   2076   // to 0.
   2077   ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices(
   2078       const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
   2079     int64 updates_rank = updates_shape.dimensions_size();
   2080     std::vector<int64> index_base(updates_rank, 0);
   2081     std::vector<int64> index_count(updates_rank, 1);
   2082     for (int64 i = 0; i < updates_rank; i++) {
   2083       bool is_update_window_dim =
   2084           absl::c_binary_search(dim_numbers.update_window_dims(), i);
   2085       if (is_update_window_dim) {
   2086         index_count[i] = updates_shape.dimensions(i);
   2087       }
   2088     }
   2089     return {std::move(index_base), std::move(index_count),
   2090             std::vector<int64>(updates_rank, 1)};
   2091   }
   2092 
   2093   // This functor computes the contribution of scatter_indices to an input index
   2094   // corresponding to an update index.  That is, given an update index I, it
   2095   // picks out the scatter indices in I and uses them to look up a scatter
   2096   // index, S, from the scatter indices tensor, and expands S into the input
   2097   // space according to scatter_dims_to_operand_dims.
   2098   //
   2099   // This is similar to the class HloEvaluator::OutputGatherIndexToInputIndex
   2100   // that does the corresponding function for Gather.
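           // For example (hypothetical dimension numbers): with
           // scatter_dims_to_operand_dims = {0} and index_vector_dim == 1, an
           // update index whose scatter component is i looks up
           // S = scatter_indices[i, 0] and contributes {S, 0, ...} to the
           // input index.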
   2101   class UpdateScatterIndexToInputIndex {
   2102    public:
   2103     // The constructor does some setup work that is amortized across all
   2104     // iterations.
   2105     explicit UpdateScatterIndexToInputIndex(
   2106         const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape,
   2107         const Shape& updates_shape, const Literal* scatter_indices)
   2108         : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) {
   2109       for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
   2110         update_dim_is_scatter_dims_.push_back(
   2111             !absl::c_binary_search(dim_numbers_.update_window_dims(), i));
   2112       }
   2113 
   2114       for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
   2115         int64 index_of_input_dim_in_index_vector =
   2116             FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i);
   2117         if (index_of_input_dim_in_index_vector ==
   2118             dim_numbers_.scatter_dims_to_operand_dims_size()) {
   2119           input_dim_value_to_index_vector_.push_back(-1);
   2120         } else {
   2121           input_dim_value_to_index_vector_.push_back(
   2122               index_of_input_dim_in_index_vector);
   2123         }
   2124       }
   2125 
   2126       index_vector_index_.resize(scatter_indices_.shape().dimensions_size());
   2127       input_index_.resize(input_shape.dimensions_size());
   2128       int64 index_vector_size =
   2129           scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
   2130       index_vector_.resize(index_vector_size);
   2131     }
   2132 
   2133     // Returns the contribution of scatter_indices to the input index
   2134     // corresponding to update_index.  See scatter_inner_loop_body.
   2135     //
    2136     // This is conceptually a stateless transformation from update_index to the
   2137     // scatter input index, but:
   2138     //
   2139     //  - Instead of allocating memory to represent the scatter input index on
   2140     //    every invocation we reuse the same storage for the result
   2141     //    (input_index_), mutating it in place.
   2142     //  - Instead of allocating buffers for temporary values like
    2143     //    index_vector_index_ and index_vector_ on every invocation, we reuse the
   2144     //    same storage for all invocations.
   2145     //
   2146     // This returns a Span into memory owned by the class.
   2147     StatusOr<absl::Span<const int64>> operator()(
   2148         absl::Span<const int64> update_index) {
   2149       PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
   2150       TF_RETURN_IF_ERROR(FetchIndexVector());
   2151       PropagateIndexVectorToInputIndex();
   2152       return absl::Span<const int64>(input_index_);
   2153     }
   2154 
   2155    private:
   2156     // Propagates the scatter index dimensions from the update index into
   2157     // index_vector_index_ by mutating index_vector_index_ in place.  Does not
   2158     // update the dim_numbers.index_vector_dim() dimension -- that's the
   2159     // dimension we iterate over in FetchIndexVector.
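             // E.g. (hypothetical): with update_dim_is_scatter_dims_ =
             // {true, false} and index_vector_dim() == 1, update index {5, 9}
             // writes 5 into index_vector_index_[0] and leaves slot 1 free for
             // FetchIndexVector to scan over.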
   2160     void PropagateUpdateIndexScatterDimsToIndexVectorIndex(
   2161         absl::Span<const int64> update_index) {
   2162       int64 index_vector_index_i = 0;
   2163       for (int64 i = 0, e = update_index.size(); i < e; i++) {
   2164         if (!update_dim_is_scatter_dims_[i]) {
   2165           continue;
   2166         }
   2167 
   2168         if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
   2169           index_vector_index_i++;
   2170         }
   2171 
   2172         index_vector_index_[index_vector_index_i++] = update_index[i];
   2173       }
   2174     }
   2175 
   2176     // Populates index_vector_ by iterating over scatter_indices_ according to
   2177     // index_vector_index_.
   2178     Status FetchIndexVector() {
   2179       int64 index_vector_dim = dim_numbers_.index_vector_dim();
   2180       for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
   2181         index_vector_index_[index_vector_dim] = i;
   2182         TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64(
   2183                                                   index_vector_index_));
   2184       }
   2185       return Status::OK();
   2186     }
   2187 
   2188     // Populates input_index_.
   2189     void PropagateIndexVectorToInputIndex() {
   2190       for (int64 i = 0, e = input_index_.size(); i < e; i++) {
   2191         if (input_dim_value_to_index_vector_[i] != -1) {
   2192           input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
   2193         }
   2194 
   2195         // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
   2196         // remains 0, as set by the constructor.
   2197       }
   2198     }
   2199 
   2200     // input_dim_value_to_index_vector_[i] tells us how to compute dimension i
   2201     // of the input index from the index vector.  See
   2202     // PropagateIndexVectorToInputIndex.
   2203     std::vector<int64> input_dim_value_to_index_vector_;
   2204 
    2205     // update_dim_is_scatter_dims_[i] is true iff update dimension i is a
    2206     // scatter dimension.
   2207     std::vector<bool> update_dim_is_scatter_dims_;
   2208 
   2209     // The buffer into which we construct an index into scatter_indices_ to
   2210     // fetch the index vector.
   2211     std::vector<int64> index_vector_index_;
   2212 
   2213     // The index vector fetched from scatter_indices_.
   2214     std::vector<int64> index_vector_;
   2215 
   2216     // The result computed by this functor.  operator() returns a Span
   2217     // into this vector.
   2218     std::vector<int64> input_index_;
   2219 
   2220     const ScatterDimensionNumbers& dim_numbers_;
   2221     const Literal& scatter_indices_;
   2222   };
   2223 
   2224   // This functor computes the contribution of the window indices in an update
   2225   // index to an input index.  That is, given an update index I it picks out the
   2226   // update window indices in I and expands it into a window index into the
   2227   // input shape.
   2228   //
   2229   // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex
   2230   // that does the corresponding function for Gather.
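           // For example (hypothetical dimension numbers): for updates of shape
           // [7,3,3] with update_window_dims = {1,2} and no
           // inserted_window_dims, update index {i, j, k} contributes window
           // offsets {j, k} to the input index.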
   2231   class UpdateWindowIndexToInputIndex {
   2232    public:
   2233     // The constructor does some setup work that is amortized across all
   2234     // iterations.
   2235     explicit UpdateWindowIndexToInputIndex(
   2236         const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape,
   2237         const Shape& updates_shape) {
   2238       std::vector<int64> window_index_to_update_index;
   2239       int64 update_index_count = 0;
   2240       for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
   2241         if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
   2242           window_index_to_update_index.push_back(update_index_count++);
   2243         } else {
   2244           update_index_count++;
   2245         }
   2246       }
   2247 
   2248       int64 window_dim_count = 0;
   2249       for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
   2250         if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
   2251           input_dim_value_to_update_index_.push_back(-1);
   2252         } else {
   2253           input_dim_value_to_update_index_.push_back(
   2254               window_index_to_update_index[window_dim_count++]);
   2255         }
   2256       }
   2257 
   2258       input_index_.resize(input_shape.dimensions_size());
   2259     }
   2260 
   2261     // Returns the contribution of the window indices to the input index
   2262     // corresponding to update_index.  See scatter_inner_loop_body.
   2263     //
   2264     // This is conceptually a stateless transformation from update_index to the
   2265     // window input index, but instead of allocating memory to represent the
   2266     // scatter input index on every invocation we reuse the same storage for the
   2267     // result (input_index_), mutating it in place.
   2268     //
   2269     // This returns a Span into memory owned by the class.
   2270     StatusOr<absl::Span<const int64>> operator()(
   2271         absl::Span<const int64> update_index) {
   2272       PropagateUpdateIndexWindowDimsToInputIndex(update_index);
   2273       return absl::Span<const int64>(input_index_);
   2274     }
    2275     // Returns, for a given 'input_dim', the corresponding update dimension
    2276     // index, or -1 if 'input_dim' is an elided window dimension.
   2277     // or -1 if 'input_dim' is an elided window dimension.
   2278     int64 input_dim_value_to_update_index(int64 input_dim) {
   2279       return input_dim_value_to_update_index_[input_dim];
   2280     }
   2281 
   2282    private:
   2283     // Propagates window dimensions from the update index to input_index_ by
   2284     // mutating input_index_ in place.
   2285     void PropagateUpdateIndexWindowDimsToInputIndex(
   2286         absl::Span<const int64> update_index) {
   2287       for (int64 i = 0, e = input_index_.size(); i < e; i++) {
   2288         if (input_dim_value_to_update_index_[i] != -1) {
   2289           input_index_[i] = update_index[input_dim_value_to_update_index_[i]];
   2290         }
   2291 
    2292         // If input_dim_value_to_update_index_[i] == -1 then input_index_[i]
   2293         // remains 0, as set by the constructor.
   2294       }
   2295     }
   2296 
    2297     // input_dim_value_to_update_index_[i] tells us how to compute dimension i
   2298     // of the input index from the update index. See
   2299     // PropagateUpdateIndexWindowDimsToInputIndex.
   2300     std::vector<int64> input_dim_value_to_update_index_;
   2301 
   2302     // The result computed by this functor.  operator() returns a Span
   2303     // into this vector.
   2304     std::vector<int64> input_index_;
   2305   };

  Status HandleScatter(HloInstruction* scatter) override {
    const ScatterDimensionNumbers& dim_numbers =
        scatter->scatter_dimension_numbers();
    const Literal& operand =
        parent_->GetEvaluatedLiteralFor(scatter->operand(0));
    Literal reshaped_scatter_indices;
    TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
                        ReshapedScatterIndices(dim_numbers.index_vector_dim(),
                                               parent_->GetEvaluatedLiteralFor(
                                                   scatter->operand(1)),
                                               &reshaped_scatter_indices));
    const Literal& updates =
        parent_->GetEvaluatedLiteralFor(scatter->operand(2));
    const Shape& updates_shape = updates.shape();
    const Shape& operand_shape = operand.shape();

    ShapeUtil::IndexIterationSpace scatter_indices_iteration_space =
        IterationSpaceForUpdateScatterIndices(updates_shape, dim_numbers);
    ShapeUtil::IndexIterationSpace window_indices_iteration_space =
        IterationSpaceForUpdateWindowIndices(updates_shape, dim_numbers);

    std::vector<int64> input_index(operand_shape.dimensions_size());
    std::vector<int64> update_index(updates_shape.dimensions_size());
    std::vector<int64> input_scatter_index_clamped(
        operand_shape.dimensions_size());

    UpdateScatterIndexToInputIndex update_scatter_index_to_input_index(
        &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape,
        updates_shape, &scatter_indices);
    UpdateWindowIndexToInputIndex update_window_index_to_input_index(
        scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape,
        updates_shape);

    // Initialize the result with the operand. This makes it easier to handle
    // the updates even when the indices are repeated.
    Literal result = operand.Clone();
    HloEvaluator embedded_evaluator;
    auto scatter_inner_loop_body =
        [&](absl::Span<const int64> update_window_index,
            absl::Span<const int64> input_scatter_index,
            absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
      TF_ASSIGN_OR_RETURN(
          absl::Span<const int64> input_window_index,
          update_window_index_to_input_index(update_window_index));
      for (int i = 0, e = update_index.size(); i < e; i++) {
        update_index[i] = update_scatter_index[i] + update_window_index[i];
        DCHECK_LT(update_index[i], updates_shape.dimensions(i));
      }
      for (int i = 0, e = input_scatter_index.size(); i < e; i++) {
        int64 update_dim =
            update_window_index_to_input_index.input_dim_value_to_update_index(
                i);
        // If 'update_dim' is -1, it means 'i' is an elided window dim. This
        // means we set the iteration index to 0, so for the purpose of the
        // following calculations we can consider the update dimension size to
        // be 1.
        int64 update_dim_size =
            update_dim == -1 ? 1 : updates_shape.dimensions(update_dim);
        // If any part of the update region is out-of-bounds, then do not
        // perform any update on the input.
        if ((input_scatter_index[i] < 0) ||
            (input_scatter_index[i] >
             operand_shape.dimensions(i) - update_dim_size)) {
          return true;
        }
      }
      for (int i = 0, e = input_index.size(); i < e; i++) {
        input_index[i] = input_scatter_index[i] + input_window_index[i];
      }

      auto result_value_literal =
          LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
      auto update_value_literal =
          LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
      Literal updated_result =
          embedded_evaluator
              .Evaluate(*scatter->to_apply(),
                        {&result_value_literal, &update_value_literal})
              .ConsumeValueOrDie();
      // Clear visit states so that we can reuse the evaluator on the same
      // computation.
      embedded_evaluator.ResetVisitStates();
      result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
      return true;
    };

    auto scatter_outer_loop_body =
        [&](absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
      TF_ASSIGN_OR_RETURN(
          absl::Span<const int64> input_scatter_index,
          update_scatter_index_to_input_index(update_scatter_index));
      TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
          updates_shape, window_indices_iteration_space,
          [&](absl::Span<const int64> update_window_index) {
            return scatter_inner_loop_body(
                update_window_index, input_scatter_index, update_scatter_index);
          }));
      return true;
    };

    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        updates_shape, scatter_indices_iteration_space,
        scatter_outer_loop_body));
    parent_->evaluated_[scatter] = std::move(result);
    return Status::OK();
  }
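
  // A worked example of the scatter semantics above (illustrative only;
  // assumes scalar updates and that to_apply is a simple `add` computation):
  // scattering updates {10, 20} at scatter indices {1, 3} into operand
  // {1, 2, 3, 4} first clones the operand and then folds each update element
  // in via to_apply, yielding {1, 12, 3, 24}.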

  Status HandleSlice(HloInstruction* slice) override {
    auto operand = slice->operand(0);
    const Shape& shape = slice->shape();
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferSliceShape(
                            operand->shape(), slice->slice_starts(),
                            slice->slice_limits(), slice->slice_strides()));
    TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape))
        << "return shape set to: " << ShapeUtil::HumanString(shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const int64 rank = operand->shape().rank();
    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    auto func = [&](absl::Span<const int64> out_index) {
      DimensionVector operand_index(rank);
      for (int64 i = 0; i < rank; ++i) {
        operand_index[i] =
            slice->slice_starts(i) + out_index[i] * slice->slice_strides(i);
      }
      return operand_literal.Get<ReturnT>(operand_index);
    };

    Literal result(shape);
    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
    parent_->evaluated_[slice] = std::move(result);
    return Status::OK();
  }
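
  // Index arithmetic sketch for the slice above (illustrative only): with
  // slice_starts = {2} and slice_strides = {3}, output indices {0, 1, 2} read
  // operand indices 2 + 0*3 = 2, 2 + 1*3 = 5, and 2 + 2*3 = 8.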

  // Enable CLZ only for int32, uint32, int64 and uint64.
  template <
      typename NativeT,
      typename std::enable_if<
          (std::is_floating_point<NativeT>::value ||
           std::is_integral<NativeT>::value || is_complex_t<NativeT>::value) &&
          !(std::is_same<NativeT, uint32>::value ||
            std::is_same<NativeT, int32>::value ||
            std::is_same<NativeT, int64>::value ||
            std::is_same<NativeT, uint64>::value)>::type* = nullptr>
  Status HandleClz(HloInstruction* clz) {
    return UnsupportedTypeError(clz);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_same<NativeT, uint32>::value ||
                std::is_same<NativeT, int32>::value>::type* = nullptr>
  Status HandleClz(HloInstruction* clz) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz],
                        ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) {
                          return 31 - tensorflow::Log2Floor(elem_operand);
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_same<NativeT, uint64>::value ||
                std::is_same<NativeT, int64>::value>::type* = nullptr>
  Status HandleClz(HloInstruction* clz) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz],
                        ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) {
                          return 63 - tensorflow::Log2Floor64(elem_operand);
                        }));
    return Status::OK();
  }

  Status HandleClz(HloInstruction* clz) override {
    return HandleClz<ElementwiseT>(clz);
  }
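
  // Worked examples for the 32-bit case above (illustrative only):
  // clz(1) = 31 - Log2Floor(1) = 31, and clz(0x00010000) = 31 - 16 = 15,
  // matching the count of leading zero bits in the 32-bit representation.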

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleSin(HloInstruction* sin) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin],
                        ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) {
                          return std::sin(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_integral<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleSin(HloInstruction* sin) {
    return UnsupportedTypeError(sin);
  }

  Status HandleSin(HloInstruction* sin) override {
    return HandleSin<ElementwiseT>(sin);
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleCos(HloInstruction* cos) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos],
                        ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) {
                          return std::cos(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_integral<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCos(HloInstruction* cos) {
    return UnsupportedTypeError(cos);
  }

  Status HandleCos(HloInstruction* cos) override {
    return HandleCos<ElementwiseT>(cos);
  }

  template <typename NativeT, typename std::enable_if<std::is_same<
                                  float, NativeT>::value>::type* = nullptr>
  Status HandleReducePrecision(HloInstruction* reduce_precision) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[reduce_precision],
        ElementWiseUnaryOp(reduce_precision, [reduce_precision](
                                                 ElementwiseT elem) {
          uint32_t value_as_int = absl::bit_cast<uint32_t>(elem);
          const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
          const uint32_t exponent_bits = reduce_precision->exponent_bits();

          // Code is based on the CPU/GPU implementation in LLVM-emitting code.
          //
          // Bits in float type:
          //   mantissa : bits [0:22]
          //   exponent : bits [23:30]
          //   sign     : bits [31]
          if (mantissa_bits < 23) {
            const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);

            // Compute rounding bias for round-to-nearest with ties to even.
            // This is equal to a base value of 0111... plus one bit if the last
            // remaining mantissa bit is 1.
            const uint32_t base_rounding_bias =
                (last_mantissa_bit_mask >> 1) - 1;
            const uint32_t x_last_mantissa_bit =
                (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
            const uint32_t x_rounding_bias =
                x_last_mantissa_bit + base_rounding_bias;

            // Add rounding bias, and mask out truncated bits.  Note that the
            // case where adding the rounding bias overflows into the exponent
            // bits is correct; the non-masked mantissa bits will all be zero,
            // and the exponent will be incremented by one.
            const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
            value_as_int = value_as_int + x_rounding_bias;
            value_as_int = value_as_int & truncation_mask;
          }
          if (exponent_bits < 8) {
            // Masks for f32 values.
            const uint32_t f32_sign_bit_mask = 1u << 31;
            const uint32_t f32_exp_bits_mask = 0xffu << 23;

            // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in
            // the most-significant bit -- is equal to 1.0f for all exponent
            // sizes.  Adding 2^(n-1)-1 to this gives us the highest
            // non-infinite exponent for a bit size of n, and subtracting
            // 2^(n-1)-1 from this gives us the lowest exponent (corresponding
            // to 0.0f).
            //
            // Thus, the f32 exponent corresponding to the highest non-infinite
            // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
            // exponent corresponding to the lowest exponent for a bit size of
            // n is (2^7-1) - 2^(n-1)-1.
            //
            // Note that we have already checked that exponent_bits >= 1.
            const uint32_t f32_exponent_bias = (1 << 7) - 1;
            const uint32_t reduced_exponent_bias =
                (1 << (exponent_bits - 1)) - 1;
            const uint32_t reduced_max_exponent =
                f32_exponent_bias + reduced_exponent_bias;
            const uint32_t reduced_min_exponent =
                f32_exponent_bias - reduced_exponent_bias;

            // Do we overflow or underflow?
            const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
            const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
            const bool x_underflows =
                x_exponent <= (reduced_min_exponent << 23);

            // Compute appropriately-signed values of zero and infinity.
            const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
            const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;

            // Force to zero or infinity if overflow or underflow.  (Note that
            // this truncates all denormal values to zero, rather than rounding
            // them.)
            value_as_int = x_overflows ? x_signed_inf : value_as_int;
            value_as_int = x_underflows ? x_signed_zero : value_as_int;
          }

          float reduced_result = absl::bit_cast<float>(value_as_int);
          if (std::isnan(elem)) {
            reduced_result = mantissa_bits > 0
                                 ? elem
                                 : std::numeric_limits<float>::infinity();
          }
          return reduced_result;
        }));
    return Status::OK();
  }
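
  // A worked example of the rounding path above (illustrative only): with
  // mantissa_bits = 7 and exponent_bits = 8 (a bfloat16-like format), the f32
  // bit pattern 0x3f800001 (the value just above 1.0f) gains the rounding
  // bias 0x7fff and is then masked with ~0xffff, producing 0x3f800000 == 1.0f.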

  template <typename NativeT, typename std::enable_if<std::is_same<
                                  double, NativeT>::value>::type* = nullptr>
  Status HandleReducePrecision(HloInstruction* reduce_precision) {
    return InvalidArgument("Double is not supported for reduce precision");
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_integral<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleReducePrecision(HloInstruction* reduce_precision) {
    return UnsupportedTypeError(reduce_precision);
  }

  Status HandleReducePrecision(HloInstruction* reduce_precision) override {
    return HandleReducePrecision<ElementwiseT>(reduce_precision);
  }

  template <
      typename NativeT,
      typename std::enable_if<
          std::is_same<NativeT, bfloat16>::value ||
          std::is_same<NativeT, Eigen::half>::value ||
          std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
          std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleIota(HloInstruction* instruction) {
    auto* iota = Cast<HloIotaInstruction>(instruction);
    const int64 iota_size = iota->shape().dimensions(iota->iota_dimension());
    // Avoid using std::vector since std::vector<bool> does not convert to
    // absl::Span<bool>.
    absl::InlinedVector<NativeT, 1> data(iota_size);
    // We don't use std::iota for two reasons:
    //
    // (1) std::iota does not support bfloat16 and float16.
    //
    // (2) std::iota saturates for floating point types when the value is not
    //     representable, but the definition of HLO iota is the value as a
    //     64-bit integer cast to the native type.
    for (int64 i = 0; i < iota_size; ++i) {
      // static_cast is required for Eigen::half (F16).
      data[i] = static_cast<NativeT>(i);
    }
    auto result = LiteralUtil::CreateR1<NativeT>(data);

    if (iota->shape().rank() > 1) {
      TF_ASSIGN_OR_RETURN(
          parent_->evaluated_[iota],
          result.Broadcast(iota->shape(), {iota->iota_dimension()}));
    } else {
      TF_RET_CHECK(iota->shape().rank() == 1);
      parent_->evaluated_[iota] = std::move(result);
    }

    return Status::OK();
  }
  template <
      typename NativeT,
      typename std::enable_if<
          !(std::is_same<NativeT, bfloat16>::value ||
            std::is_same<NativeT, Eigen::half>::value ||
            std::is_integral<NativeT>::value || is_complex_t<NativeT>::value ||
            std::is_floating_point<NativeT>::value)>::type* = nullptr>
  Status HandleIota(HloInstruction* iota) {
    return UnsupportedTypeError(iota);
  }
  Status HandleIota(HloInstruction* iota) override {
    return HandleIota<ReturnT>(iota);
  }
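
  // Shape sketch for the broadcast above (illustrative only): an iota of
  // shape [2, 3] with iota_dimension = 1 first builds the R1 literal
  // {0, 1, 2} and then broadcasts it along dimension 1, so every row of the
  // result is {0, 1, 2}.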

  template <typename NativeT,
            typename std::enable_if<
                !(std::is_integral<NativeT>::value ||
                  std::is_floating_point<NativeT>::value)>::type* = nullptr>
  Status HandleRng(HloInstruction* random) {
    return UnsupportedTypeError(random);
  }
  template <typename NativeT,
            typename std::enable_if<
                (std::is_floating_point<NativeT>::value)>::type* = nullptr>
  Status HandleRng(HloInstruction* random) {
    RandomDistribution distribution = random->random_distribution();
    const auto result_shape = random->shape();
    Literal result(result_shape);

    switch (distribution) {
      case RNG_UNIFORM: {
        const Literal& low =
            parent_->GetEvaluatedLiteralFor(random->operand(0));
        const Literal& high =
            parent_->GetEvaluatedLiteralFor(random->operand(1));

        // std::uniform_real_distribution(a, b) can sometimes return a value
        // equal to b.  Unclear if this is a spec bug or an implementation bug
        // or WAI [0] [1] [2].  Anyway for our purposes we want a half-open
        // interval, so we have to re-sample if we get `b` out.
        //
        // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176
        // [1] https://bugs.llvm.org/show_bug.cgi?id=18767
        // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524
        auto low_val = low.Get<NativeT>({});
        auto high_val = high.Get<NativeT>({});
        std::uniform_real_distribution<NativeT> generator(low_val, high_val);
        TF_RETURN_IF_ERROR(
            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
              while (true) {
                NativeT v = generator(parent_->engine_);
                if (v != high_val) {
                  return v;
                }
              }
            }));
        break;
      }
      case RNG_NORMAL: {
        const Literal& mean =
            parent_->GetEvaluatedLiteralFor(random->operand(0));
        const Literal& stddev =
            parent_->GetEvaluatedLiteralFor(random->operand(1));

        std::normal_distribution<NativeT> generator(mean.Get<NativeT>({}),
                                                    stddev.Get<NativeT>({}));

        TF_RETURN_IF_ERROR(
            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
              return generator(parent_->engine_);
            }));
        break;
      }
      default:
        return UnimplementedStrCat("The distribution ",
                                   RandomDistribution_Name(distribution),
                                   " is not implemented.");
    }
    parent_->evaluated_[random] = std::move(result);
    return Status::OK();
  }
  template <typename NativeT,
            typename std::enable_if<(std::is_integral<NativeT>::value)>::type* =
                nullptr>
  Status HandleRng(HloInstruction* random) {
    RandomDistribution distribution = random->random_distribution();
    const auto result_shape = random->shape();
    Literal result(result_shape);

    switch (distribution) {
      case RNG_UNIFORM: {
        const Literal& low =
            parent_->GetEvaluatedLiteralFor(random->operand(0));
        const Literal& high =
            parent_->GetEvaluatedLiteralFor(random->operand(1));

        // Note that std::uniform_int_distribution assumes a closed interval,
        // i.e., [low, high], but we want [low, high) instead, so high-1 is
        // used as the upper bound.
        std::uniform_int_distribution<int64> generator(
            low.Get<NativeT>({}), high.Get<NativeT>({}) - 1);

        TF_RETURN_IF_ERROR(
            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
              return static_cast<NativeT>(generator(parent_->engine_));
            }));
        break;
      }
      case RNG_NORMAL: {
        return Unimplemented(
            "Normal distribution is not supported for integral types.");
      }
      default:
        return UnimplementedStrCat("The distribution ",
                                   RandomDistribution_Name(distribution),
                                   " is not implemented.");
    }
    parent_->evaluated_[random] = std::move(result);
    return Status::OK();
  }
  Status HandleRng(HloInstruction* random) override {
    return HandleRng<ReturnT>(random);
  }
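
  // Range sketch for the integral RNG_UNIFORM case above (illustrative only):
  // with low = 0 and high = 4, the generator is constructed over the closed
  // interval [0, 3], so samples are drawn uniformly from {0, 1, 2, 3} and the
  // half-open HLO interval [low, high) is respected.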

 private:
  // Creates a vector of multipliers which can be used to create a linear index
  // into shape.
  //
  // Given the multidimensional index {i1, ..., iN} and
  // M = MakeDimMultipliers(shape), the corresponding linear index LI is simply
  //
  //   LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
  //
  // This lets you calculate LI given the multidimensional indices in any order.
  static DimensionVector MakeDimMultipliers(const Shape& shape) {
    DimensionVector v(shape.rank());
    int64 scale = 1;
    for (auto dim : LayoutUtil::MinorToMajor(shape)) {
      v[dim] = scale;
      scale *= shape.dimensions(dim);
    }
    return v;
  }
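
  // Worked example (illustrative only): for a shape with dimensions {2, 3, 4}
  // and the default minor-to-major layout {2, 1, 0}, the multipliers are
  // v = {12, 4, 1}, so the index {1, 2, 3} maps to LI = 1*12 + 2*4 + 3*1 = 23.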

  // For one particular placement of a window in a base shape (the placement is
  // represented as `window_count_index`), iterates inside the window.
  // Translates each window index into a base index. If the base index is
  // within bounds, calls `f` with the base index.
  static void IterateThroughWindow(
      const Shape& window_shape, const Window& window, const Shape& base_shape,
      const absl::Span<const int64>& window_count_index,
      const std::function<void(const std::vector<int64>&)>& f) {
    const int64 rank = base_shape.rank();
    DimensionVector window_index(rank);
    std::fill(window_index.begin(), window_index.end(), 0);
    do {
      std::vector<int64> base_index(rank);
      bool out_of_bound = false;
      for (int64 i = 0; i < rank; ++i) {
        base_index[i] =
            window_count_index[i] * window.dimensions(i).stride() +
            window_index[i] * window.dimensions(i).window_dilation() -
            window.dimensions(i).padding_low();
        // We are not in the base area if the dilation placed us out of bounds.
        if (base_index[i] % window.dimensions(i).base_dilation() != 0) {
          out_of_bound = true;
          break;
        }
        // Apply the dilation to the base area.
        base_index[i] /= window.dimensions(i).base_dilation();
        if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
          out_of_bound = true;
          break;
        }
      }
      if (!out_of_bound) {
        f(base_index);
      }
    } while (
        IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index)));
  }
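
  // Index sketch for the window math above (illustrative only): with stride 2,
  // window_dilation 1, base_dilation 1, and padding_low 1 in some dimension,
  // the placement window_count_index = 0 maps window_index 0 to base index
  // 0*2 + 0*1 - 1 = -1, which is out of bounds and therefore skipped.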

  template <typename IndexT>
  StatusOr<Literal> DynamicSlice(
      const Literal& operand_literal,
      absl::Span<HloInstruction* const> start_indices,
      const Shape& result_shape) {
    std::vector<int64> start;

    for (HloInstruction* index : start_indices) {
      start.push_back(
          parent_->GetEvaluatedLiteralFor(index).GetFirstElement<IndexT>());
    }

    // Clamp the start indices so the slice is in-bounds w.r.t. the operand.
    for (int64 i = 0; i < start.size(); ++i) {
      start[i] = std::min<int64>(
          std::max(int64{0}, start[i]),
          operand_literal.shape().dimensions(i) - result_shape.dimensions(i));
    }

    std::vector<int64> operand_indices(start.size());
    Literal result(result_shape);
    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
          for (int64 i = 0; i < operand_indices.size(); ++i) {
            CHECK_GE(multi_index[i] + start[i], 0);
            operand_indices[i] = multi_index[i] + start[i];
          }

          auto result = operand_literal.Get<ReturnT>(operand_indices);
          return result;
        }));

    return std::move(result);
  }
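
  // Clamping sketch (illustrative only): for an operand dimension of size 10
  // and a result dimension of size 4, a requested start index of 9 is clamped
  // to min(max(0, 9), 10 - 4) = 6, so the slice always stays in bounds.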

  template <typename IndexT>
  StatusOr<Literal> DynamicUpdateSlice(
      const Literal& operand_literal, const Literal& update_literal,
      absl::Span<HloInstruction* const> start_indices) {
    auto result = operand_literal.Clone();
    const auto rank = result.shape().rank();
    std::vector<int64> start;
    for (HloInstruction* index : start_indices) {
      start.push_back(
          parent_->GetEvaluatedLiteralFor(index).GetFirstElement<IndexT>());
    }

    // Clamp the update start indices so the slice is in-bounds w.r.t. the
    // operand.
    for (int64 i = 0; i < rank; ++i) {
      start[i] = std::min<int64>(
          std::max<int64>(0, start[i]),
          result.shape().dimensions(i) - update_literal.shape().dimensions(i));
    }
    std::vector<int64> result_index(rank, 0);

    auto func = [&](absl::Span<const int64> update_index) {
      std::transform(update_index.begin(), update_index.end(), start.begin(),
                     result_index.begin(), std::plus<int64>());
      result.Set<ReturnT>(result_index,
                          update_literal.Get<ReturnT>(update_index));
      return true;
    };

    std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
    std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
    ShapeUtil::ForEachIndex(update_literal.shape(), base,
                            AsInt64Slice(update_literal.shape().dimensions()),
                            step, func);

    return std::move(result);
  }

  StatusOr<Literal> ElementWiseUnaryOp(
      HloInstruction* instruction,
      const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
    const Literal& operand_literal =
        parent_->GetEvaluatedLiteralFor(instruction->operand(0));
    TF_ASSIGN_OR_RETURN(
        auto result_literal,
        (HloEvaluator::ElementWiseUnaryOpImpl<ReturnT, ReturnT>(
            instruction, ConvertUnaryFunction(unary_op), operand_literal)));

    return std::move(result_literal);
  }

  StatusOr<Literal> ElementWiseBinaryOp(
      HloInstruction* instruction,
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
          binary_op) {
    const auto shape = instruction->shape();
    const auto* lhs = instruction->operand(0);
    const auto* rhs = instruction->operand(1);
    TF_RET_CHECK(ShapeUtil::SameDimensions(shape, rhs->shape()));
    TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    Literal result(shape);

    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
          return ConvertBinaryFunction(binary_op)(
              lhs_literal.Get<ReturnT>(multi_index),
              rhs_literal.Get<ReturnT>(multi_index));
        }));
    return std::move(result);
  }

  template <typename LhsType, typename RhsType, typename EhsType>
  StatusOr<Literal> ElementwiseTernaryOp(
      HloInstruction* instruction,
      const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
    const auto shape = instruction->shape();
    const auto* lhs = instruction->operand(0);
    const auto* rhs = instruction->operand(1);
    const auto* ehs = instruction->operand(2);
    TF_RET_CHECK(ShapeUtil::SameDimensions(shape, lhs->shape()));
    TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
    TF_RET_CHECK(ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()));

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
    const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);

    Literal result(shape);

    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
          return ternary_op(lhs_literal.Get<LhsType>(multi_index),
                            rhs_literal.Get<RhsType>(multi_index),
                            ehs_literal.Get<EhsType>(multi_index));
        }));

    return std::move(result);
  }

  template <typename NativeT>
  static bool IsShiftOutOfBounds(NativeT rhs) {
    typedef typename std::make_unsigned<NativeT>::type UnsignedT;
    UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT;
    UnsignedT rhs_unsigned = static_cast<UnsignedT>(rhs);
    return rhs_unsigned >= lhs_size_unsigned;
  }
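
  // Worked example (illustrative only): for NativeT = int32, a shift amount
  // of 31 is in bounds, while 32 or -1 (which casts to 0xffffffff) compares
  // >= 32 as unsigned and is therefore flagged as out of bounds.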

  HloEvaluator* parent_;
};

// These extern templates prevent users of this class from implicitly
// instantiating it.  We explicitly instantiate this class in the various
// hlo_evaluator_typed_visitor*.cc files.
extern template class HloEvaluatorTypedVisitor<bool>;
extern template class HloEvaluatorTypedVisitor<uint8>;
extern template class HloEvaluatorTypedVisitor<uint32>;
extern template class HloEvaluatorTypedVisitor<uint64>;
extern template class HloEvaluatorTypedVisitor<int8>;
extern template class HloEvaluatorTypedVisitor<int32>;
extern template class HloEvaluatorTypedVisitor<int64>;
extern template class HloEvaluatorTypedVisitor<Eigen::half, float>;
extern template class HloEvaluatorTypedVisitor<float>;
extern template class HloEvaluatorTypedVisitor<double>;
extern template class HloEvaluatorTypedVisitor<complex64>;
extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_