Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
     16 
     17 #include <algorithm>
     18 #include <cmath>
     19 #include <cstdlib>
     20 #include <functional>
     21 #include <iterator>
     22 #include <string>
     23 #include <type_traits>
     24 #include <vector>
     25 
     26 #include "absl/algorithm/container.h"
     27 #include "absl/container/inlined_vector.h"
     28 #include "absl/memory/memory.h"
     29 #include "absl/strings/string_view.h"
     30 #include "tensorflow/compiler/xla/index_util.h"
     31 #include "tensorflow/compiler/xla/layout_util.h"
     32 #include "tensorflow/compiler/xla/literal_util.h"
     33 #include "tensorflow/compiler/xla/map_util.h"
     34 #include "tensorflow/compiler/xla/primitive_util.h"
     35 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
     36 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
     37 #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
     38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     39 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     40 #include "tensorflow/compiler/xla/service/hlo_query.h"
     41 #include "tensorflow/compiler/xla/service/shape_inference.h"
     42 #include "tensorflow/compiler/xla/shape_util.h"
     43 #include "tensorflow/compiler/xla/statusor.h"
     44 #include "tensorflow/compiler/xla/types.h"
     45 #include "tensorflow/compiler/xla/util.h"
     46 #include "tensorflow/compiler/xla/window_util.h"
     47 #include "tensorflow/core/lib/core/bitmap.h"
     48 #include "tensorflow/core/lib/core/errors.h"
     49 #include "tensorflow/core/lib/core/status.h"
     50 #include "tensorflow/core/platform/logging.h"
     51 #include "tensorflow/core/platform/protobuf.h"
     52 #include "tensorflow/core/platform/types.h"
     53 
     54 namespace xla {
     55 
     56 namespace {
     57 
     58 template <typename OperandT>
     59 StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
     60                           LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
     61   std::function<bool(OperandT, OperandT)> compare_op;
     62   switch (direction) {
     63     case ComparisonDirection::kEq:
     64       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
     65         return lhs_el == rhs_el;
     66       };
     67       break;
     68     case ComparisonDirection::kNe:
     69       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
     70         return lhs_el != rhs_el;
     71       };
     72       break;
     73     case ComparisonDirection::kGe:
     74       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
     75         return lhs_el >= rhs_el;
     76       };
     77       break;
     78     case ComparisonDirection::kGt:
     79       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
     80         return lhs_el > rhs_el;
     81       };
     82       break;
     83     case ComparisonDirection::kLe:
     84       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
     85         return lhs_el <= rhs_el;
     86       };
     87       break;
     88     case ComparisonDirection::kLt:
     89       compare_op = [](OperandT lhs_el, OperandT rhs_el) {
     90         return lhs_el < rhs_el;
     91       };
     92       break;
     93   }
     94 
     95   Literal result(shape);
     96   TF_RETURN_IF_ERROR(
     97       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
     98         return compare_op(lhs_literal.Get<OperandT>(multi_index),
     99                           rhs_literal.Get<OperandT>(multi_index));
    100       }));
    101 
    102   return std::move(result);
    103 }
    104 
    105 template <>
    106 StatusOr<Literal> Compare<complex64>(const Shape& shape,
    107                                      ComparisonDirection direction,
    108                                      LiteralSlice lhs_literal,
    109                                      LiteralSlice rhs_literal) {
    110   std::function<bool(complex64, complex64)> compare_op;
    111   switch (direction) {
    112     case ComparisonDirection::kEq:
    113       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
    114         return lhs_el == rhs_el;
    115       };
    116       break;
    117     case ComparisonDirection::kNe:
    118       compare_op = [](complex64 lhs_el, complex64 rhs_el) {
    119         return lhs_el != rhs_el;
    120       };
    121       break;
    122     default:
    123       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
    124                  << ComparisonDirectionToString(direction);
    125   }
    126 
    127   Literal result(shape);
    128   TF_RETURN_IF_ERROR(
    129       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
    130         return compare_op(lhs_literal.Get<complex64>(multi_index),
    131                           rhs_literal.Get<complex64>(multi_index));
    132       }));
    133 
    134   return std::move(result);
    135 }
    136 
    137 template <>
    138 StatusOr<Literal> Compare<complex128>(const Shape& shape,
    139                                       ComparisonDirection direction,
    140                                       LiteralSlice lhs_literal,
    141                                       LiteralSlice rhs_literal) {
    142   std::function<bool(complex128, complex128)> compare_op;
    143   switch (direction) {
    144     case ComparisonDirection::kEq:
    145       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
    146         return lhs_el == rhs_el;
    147       };
    148       break;
    149     case ComparisonDirection::kNe:
    150       compare_op = [](complex128 lhs_el, complex128 rhs_el) {
    151         return lhs_el != rhs_el;
    152       };
    153       break;
    154     default:
    155       LOG(FATAL) << "unhandled direction for conversion to Comparison: "
    156                  << ComparisonDirectionToString(direction);
    157   }
    158 
    159   Literal result(shape);
    160   TF_RETURN_IF_ERROR(
    161       result.Populate<bool>([&](absl::Span<const int64> multi_index) {
    162         return compare_op(lhs_literal.Get<complex128>(multi_index),
    163                           rhs_literal.Get<complex128>(multi_index));
    164       }));
    165 
    166   return std::move(result);
    167 }
    168 
    169 }  // namespace
    170 
    171 // Note that unsupported types by the typed visitor does not necessarily imply
    172 // the non-typed HloEvaluator (parent evaluator) would not support them either
    173 // in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent
    174 // type-agnostic evaluator will be able to accept Tuple primitive type, whereas
    175 // HloEvaluatorTypedVisitor cannot.
    176 HloEvaluator::HloEvaluator(int64 max_loop_iterations)
    177     : max_loop_iterations_(max_loop_iterations) {
    178   typed_visitors_[PRED] =
    179       absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
    180   typed_visitors_[U8] =
    181       absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
    182   typed_visitors_[U16] =
    183       absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
    184   typed_visitors_[U32] =
    185       absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
    186   typed_visitors_[U64] =
    187       absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
    188   typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
    189   typed_visitors_[S16] =
    190       absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
    191   typed_visitors_[S32] =
    192       absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
    193   typed_visitors_[S64] =
    194       absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
    195   typed_visitors_[F16] =
    196       absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
    197   typed_visitors_[F32] =
    198       absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
    199   typed_visitors_[F64] =
    200       absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
    201   typed_visitors_[C64] =
    202       absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
    203   typed_visitors_[C128] =
    204       absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
    205 
    206   // Most of the evaluator computations we use don't support BF16 (e.g.,
    207   // std::ceil, std::tanh). To make evaluator work with BF16, we set all
    208   // elementwise computations to be done in F32 and do BF16<->F32 conversion
    209   // around the input and the output of the computations.
    210   typed_visitors_[BF16] =
    211       absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
    212 
    213   typed_visitors_[TUPLE] =
    214       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
    215         return Unimplemented(
    216             "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
    217       });
    218   typed_visitors_[OPAQUE] =
    219       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
    220         return Unimplemented(
    221             "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
    222       });
    223   typed_visitors_[TOKEN] =
    224       absl::make_unique<FunctionVisitor>([](HloInstruction*) {
    225         return Unimplemented(
    226             "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
    227       });
    228 }
    229 
    230 StatusOr<Literal> HloEvaluator::Evaluate(
    231     const HloComputation& computation,
    232     absl::Span<const Literal* const> arg_literals) {
    233   CHECK(computation.parent() != nullptr);
    234   XLA_VLOG_LINES(
    235       2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
    236 
    237   if (arg_literals.size() != computation.num_parameters()) {
    238     return InvalidArgument(
    239         "Expected %d argument%s, but got %d.", computation.num_parameters(),
    240         computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
    241   }
    242   for (int64 i = 0; i < arg_literals.size(); ++i) {
    243     const auto& computation_shape =
    244         computation.parameter_instruction(i)->shape();
    245     const auto& arg_shape = arg_literals[i]->shape();
    246     if (!ShapeUtil::Equal(computation_shape, arg_shape)) {
    247       return InvalidArgument(
    248           "Shape mismatch at parameter %d. Computation expected %s, but arg "
    249           "was %s.",
    250           i, ShapeUtil::HumanStringWithLayout(computation_shape),
    251           ShapeUtil::HumanString(arg_shape));
    252     }
    253   }
    254 
    255   evaluated_.clear();
    256   arg_literals_.clear();
    257   for (const auto& literal_ptr : arg_literals) {
    258     arg_literals_.push_back(&*literal_ptr);
    259   }
    260 
    261   // Re-seed RNG, either from the configuration's seed or a monotonic
    262   // per-evaluator seed (which prevents two evaluators from returning the same
    263   // random sequence).
    264   if (computation.parent()->config().seed()) {
    265     seed_ = computation.parent()->config().seed();
    266   } else {
    267     // Start global_seed at a (true) random value.
    268     static std::atomic<uint64> global_seed{std::random_device()()};
    269     seed_ = global_seed.fetch_add(1);
    270   }
    271   engine_.seed(seed_);
    272 
    273   TF_RETURN_IF_ERROR(computation.Accept(this));
    274   return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
    275 }
    276 
    277 StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
    278   if (instruction->opcode() == HloOpcode::kParameter) {
    279     return tensorflow::errors::FailedPrecondition(
    280         "Cannot evaluate a parameter.");
    281   }
    282   if (!hlo_query::AllOperandsAreConstants(*instruction)) {
    283     return tensorflow::errors::FailedPrecondition(
    284         "Not all operands are constants.");
    285   }
    286 
    287   arg_literals_.clear();
    288   evaluated_.clear();
    289 
    290   TF_RETURN_IF_ERROR(Preprocess(instruction));
    291   TF_RETURN_IF_ERROR(instruction->Visit(this));
    292   TF_RETURN_IF_ERROR(Postprocess(instruction));
    293   return GetEvaluatedLiteralFor(instruction).Clone();
    294 }
    295 
    296 bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
    297   CHECK(result != nullptr);
    298   auto result_or = Evaluate(instruction);
    299   if (!result_or.ok()) {
    300     VLOG(1) << "TryEvaluate failed:" << result_or.status();
    301     return false;
    302   }
    303 
    304   *result = result_or.ConsumeValueOrDie();
    305   return true;
    306 }
    307 
    308 StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
    309     const HloInstruction* instruction,
    310     const std::unordered_map<const HloInstruction*, const Literal*>&
    311         substitutions) {
    312   std::vector<std::unique_ptr<HloInstruction>> owned_operands;
    313   for (const HloInstruction* operand : instruction->operands()) {
    314     auto it = substitutions.find(operand);
    315     if (it == substitutions.end()) {
    316       owned_operands.push_back(operand->Clone());
    317     } else {
    318       owned_operands.push_back(
    319           HloInstruction::CreateConstant(it->second->Clone()));
    320     }
    321   }
    322 
    323   std::vector<HloInstruction*> operands;
    324   operands.reserve(owned_operands.size());
    325   for (auto& operand : owned_operands) {
    326     operands.push_back(operand.get());
    327   }
    328 
    329   std::unique_ptr<HloInstruction> cloned_instruction =
    330       instruction->CloneWithNewOperands(instruction->shape(), operands);
    331   auto result = Evaluate(cloned_instruction.get());
    332 
    333   return result;
    334 }
    335 
    336 StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
    337     HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
    338   std::unique_ptr<HloInstruction> lhs_instr =
    339       HloInstruction::CreateConstant(lhs.Clone());
    340   std::unique_ptr<HloInstruction> rhs_instr =
    341       HloInstruction::CreateConstant(rhs.Clone());
    342 
    343   std::unique_ptr<HloInstruction> cloned_instruction =
    344       HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
    345                                    rhs_instr.get());
    346   auto result = Evaluate(cloned_instruction.get());
    347 
    348   return result;
    349 }
    350 
    351 StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
    352     HloOpcode opcode, const Literal& operand) {
    353   std::unique_ptr<HloInstruction> operand_instr =
    354       HloInstruction::CreateConstant(operand.Clone());
    355 
    356   std::unique_ptr<HloInstruction> cloned_instruction =
    357       HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
    358   auto result = Evaluate(cloned_instruction.get());
    359 
    360   return result;
    361 }
    362 
    363 StatusOr<Literal> HloEvaluator::EvaluateDotOp(
    364     const DotDimensionNumbers& dim_numbers,
    365     const PrecisionConfig& precision_config, const Literal& lhs,
    366     const Literal& rhs) {
    367   std::unique_ptr<HloInstruction> lhs_instr =
    368       HloInstruction::CreateConstant(lhs.Clone());
    369   std::unique_ptr<HloInstruction> rhs_instr =
    370       HloInstruction::CreateConstant(rhs.Clone());
    371 
    372   TF_ASSIGN_OR_RETURN(
    373       Shape dot_shape,
    374       ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
    375 
    376   std::unique_ptr<HloInstruction> cloned_instruction =
    377       HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
    378                                 dim_numbers, precision_config);
    379   return Evaluate(cloned_instruction.get());
    380 }
    381 
    382 Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
    383   const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
    384   Literal result(bitcast->shape());
    385   TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
    386   memcpy(result.untyped_data(), operand_literal.untyped_data(),
    387          operand_literal.size_bytes());
    388   evaluated_[bitcast] = std::move(result);
    389   return Status::OK();
    390 }
    391 
    392 Status HloEvaluator::HandleGetDimensionSize(
    393     HloInstruction* get_dimension_size) {
    394   HloInstruction* operand = get_dimension_size->mutable_operand(0);
    395   int64 dim = get_dimension_size->dimension();
    396   if (dynamic_dimension_inference_ == nullptr) {
    397     return InvalidArgument(
    398         "Evaluator cannot evaluate get_dimension_size without "
    399         "set_dynamic_dimension_inference.");
    400   }
    401   HloInstruction* dynamic_size =
    402       dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
    403   if (dynamic_size != nullptr) {
    404     evaluated_[get_dimension_size] =
    405         GetEvaluatedLiteralFor(dynamic_size).Clone();
    406     return Status::OK();
    407   }
    408 
    409   const Shape& shape = get_dimension_size->operand(0)->shape();
    410   Literal output(ShapeUtil::MakeShape(U32, {}));
    411   output.PopulateWithValue(
    412       static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
    413   evaluated_[get_dimension_size] = std::move(output);
    414   return Status::OK();
    415 }
    416 
    417 Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
    418   // Nothing to do other than sanity checks. Parameters' values are stored in
    419   // arg_literals_.
    420   CHECK_LT(parameter->parameter_number(), arg_literals_.size());
    421 
    422 #ifndef NDEBUG
    423   const Literal* input_literal = arg_literals_[parameter->parameter_number()];
    424   VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
    425   DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
    426       << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
    427       << ", but input literal shape is: "
    428       << ShapeUtil::HumanString(input_literal->shape());
    429 #endif
    430 
    431   return Status::OK();
    432 }
    433 
    434 Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }
    435 
    436 Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
    437   TF_ASSIGN_OR_RETURN(
    438       evaluated_[reshape],
    439       GetEvaluatedLiteralFor(reshape->operand(0))
    440           .Reshape(AsInt64Slice(reshape->shape().dimensions())));
    441   return Status::OK();
    442 }
    443 
    444 Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
    445   evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
    446                               .Transpose(transpose->dimensions());
    447   return Status::OK();
    448 }
    449 
    450 Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
    451   absl::Span<HloInstruction* const> operands(concatenate->operands());
    452   // The result concatenate dimension is going to be the sum of all
    453   // concatenate dimensions of the operands taking part of the operation.
    454   const Shape& reference_shape = operands[0]->shape();
    455   CHECK(reference_shape.IsArray());
    456   const int64 rank = reference_shape.rank();
    457   const int64 concat_dim = concatenate->dimensions()[0];
    458   CHECK_GE(concat_dim, 0);
    459   CHECK_LT(concat_dim, rank);
    460 
    461   DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
    462                                     reference_shape.dimensions().end());
    463 
    464   for (int64 i = 1; i < operands.size(); ++i) {
    465     const Shape& operand_shape = operands[i]->shape();
    466     CHECK(operand_shape.IsArray());
    467     // Accumulate the concat dimension from all tensors taking part to the
    468     // operation.
    469     concat_dimensions[concat_dim] +=
    470         ShapeUtil::GetDimension(operand_shape, concat_dim);
    471   }
    472 
    473   auto result_literal = LiteralUtil::CreateFromDimensions(
    474       reference_shape.element_type(), concat_dimensions);
    475   DimensionVector source_indices(rank, 0);
    476   DimensionVector dest_indices(concat_dimensions.size(), 0);
    477 
    478   for (auto operand : operands) {
    479     const Shape& operand_shape = operand->shape();
    480     TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
    481         GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
    482         AsInt64Slice(operand_shape.dimensions())));
    483     dest_indices[concat_dim] +=
    484         ShapeUtil::GetDimension(operand_shape, concat_dim);
    485   }
    486 
    487   evaluated_[concatenate] = std::move(result_literal);
    488   return Status::OK();
    489 }
    490 
    491 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
    492   auto operand = is_finite->operand(0);
    493   auto elem_ty = operand->shape().element_type();
    494   switch (elem_ty) {
    495     case PRED:
    496     case TUPLE:
    497     case OPAQUE:
    498     case TOKEN:
    499     case S8:
    500     case S16:
    501     case S32:
    502     case S64:
    503     case U8:
    504     case U16:
    505     case U32:
    506     case U64:
    507     case C64:
    508     case C128:
    509     // Explicitly enumerate all types in this switch so that when we add a new
    510     // type, we'll get a compile error here.
    511     case PRIMITIVE_TYPE_INVALID:
    512     case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
    513     case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
    514       return InvalidArgument(
    515           "expected element type in shape to be floating point, but "
    516           "got: %s",
    517           PrimitiveType_Name(elem_ty));
    518 
    519     case F16: {
    520       auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
    521           is_finite,
    522           [](Eigen::half elem_operand) {
    523             return std::isfinite(static_cast<float>(elem_operand));
    524           },
    525           GetEvaluatedLiteralFor(operand));
    526       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
    527       break;
    528     }
    529     case BF16: {
    530       auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
    531           is_finite,
    532           [](bfloat16 elem_operand) {
    533             return std::isfinite(static_cast<float>(elem_operand));
    534           },
    535           GetEvaluatedLiteralFor(operand));
    536       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
    537       break;
    538     }
    539     case F32: {
    540       auto result_or = ElementWiseUnaryOpImpl<bool, float>(
    541           is_finite,
    542           [](float elem_operand) { return std::isfinite(elem_operand); },
    543           GetEvaluatedLiteralFor(operand));
    544       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
    545       break;
    546     }
    547     case F64: {
    548       auto result_or = ElementWiseUnaryOpImpl<bool, double>(
    549           is_finite,
    550           [](double elem_operand) { return std::isfinite(elem_operand); },
    551           GetEvaluatedLiteralFor(operand));
    552       TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
    553       break;
    554     }
    555   }
    556 
    557   return Status::OK();
    558 }
    559 
    560 Status HloEvaluator::HandleReal(HloInstruction* real) {
    561   auto operand = real->operand(0);
    562   switch (operand->shape().element_type()) {
    563     case BF16: {
    564       auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
    565           real, [](bfloat16 elem_operand) { return elem_operand; },
    566           GetEvaluatedLiteralFor(operand));
    567       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
    568       break;
    569     }
    570     case C64: {
    571       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
    572           real, [](complex64 elem_operand) { return std::real(elem_operand); },
    573           GetEvaluatedLiteralFor(operand));
    574       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
    575       break;
    576     }
    577     case C128: {
    578       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
    579           real, [](complex128 elem_operand) { return std::real(elem_operand); },
    580           GetEvaluatedLiteralFor(operand));
    581       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
    582       break;
    583     }
    584     case F16: {
    585       auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
    586           real, [](Eigen::half elem_operand) { return elem_operand; },
    587           GetEvaluatedLiteralFor(operand));
    588       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
    589       break;
    590     }
    591     case F32: {
    592       auto result_or = ElementWiseUnaryOpImpl<float, float>(
    593           real, [](float elem_operand) { return elem_operand; },
    594           GetEvaluatedLiteralFor(operand));
    595       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
    596       break;
    597     }
    598     case F64: {
    599       auto result_or = ElementWiseUnaryOpImpl<double, double>(
    600           real, [](double elem_operand) { return elem_operand; },
    601           GetEvaluatedLiteralFor(operand));
    602       TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
    603       break;
    604     }
    605     default:
    606       LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
    607                  << PrimitiveType_Name(operand->shape().element_type());
    608   }
    609 
    610   return Status::OK();
    611 }
    612 
    613 Status HloEvaluator::HandleImag(HloInstruction* imag) {
    614   auto operand = imag->operand(0);
    615   switch (operand->shape().element_type()) {
    616     case C64: {
    617       auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
    618           imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
    619           GetEvaluatedLiteralFor(imag->operand(0)));
    620 
    621       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
    622       break;
    623     }
    624     case C128: {
    625       auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
    626           imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
    627           GetEvaluatedLiteralFor(imag->operand(0)));
    628 
    629       TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
    630       break;
    631     }
    632     default:
    633       LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
    634                  << PrimitiveType_Name(operand->shape().element_type());
    635   }
    636 
    637   return Status::OK();
    638 }
    639 
    640 Status HloEvaluator::HandleComplex(HloInstruction* complex) {
    641   const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
    642   const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
    643   TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
    644 
    645   Literal result(complex->shape());
    646   switch (complex->shape().element_type()) {
    647     case C64: {
    648       TF_RETURN_IF_ERROR(
    649           result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
    650             return std::complex<float>(real.Get<float>(multi_index),
    651                                        imag.Get<float>(multi_index));
    652           }));
    653       break;
    654     }
    655     case C128: {
    656       TF_RETURN_IF_ERROR(
    657           result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
    658             return std::complex<float>(real.Get<double>(multi_index),
    659                                        imag.Get<double>(multi_index));
    660           }));
    661       break;
    662     }
    663     default:
    664       LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
    665                  << PrimitiveType_Name(complex->shape().element_type());
    666   }
    667 
    668   evaluated_[complex] = std::move(result);
    669   return Status::OK();
    670 }
    671 
    672 Status HloEvaluator::HandleCompare(HloInstruction* compare) {
    673   ComparisonDirection direction = compare->comparison_direction();
    674   auto lhs = compare->operand(0);
    675   auto rhs = compare->operand(1);
    676   DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
    677          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));
    678 
    679   TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
    680 
    681   const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
    682   const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
    683 
    684   // Note here we switch on the operand's type.
    685   switch (lhs->shape().element_type()) {
    686     case PRED: {
    687       TF_ASSIGN_OR_RETURN(
    688           evaluated_[compare],
    689           Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
    690     } break;
    691     case U8: {
    692       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    693                           Compare<uint8>(compare->shape(), direction,
    694                                          lhs_literal, rhs_literal));
    695     } break;
    696     case U16: {
    697       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    698                           Compare<uint16>(compare->shape(), direction,
    699                                           lhs_literal, rhs_literal));
    700     } break;
    701     case U32: {
    702       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    703                           Compare<uint32>(compare->shape(), direction,
    704                                           lhs_literal, rhs_literal));
    705     } break;
    706     case U64: {
    707       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    708                           Compare<uint64>(compare->shape(), direction,
    709                                           lhs_literal, rhs_literal));
    710     } break;
    711     case S8: {
    712       TF_ASSIGN_OR_RETURN(
    713           evaluated_[compare],
    714           Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
    715     } break;
    716     case S16: {
    717       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    718                           Compare<int16>(compare->shape(), direction,
    719                                          lhs_literal, rhs_literal));
    720     } break;
    721     case S32: {
    722       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    723                           Compare<int32>(compare->shape(), direction,
    724                                          lhs_literal, rhs_literal));
    725     } break;
    726     case S64: {
    727       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    728                           Compare<int64>(compare->shape(), direction,
    729                                          lhs_literal, rhs_literal));
    730     } break;
    731     case F16: {
    732       TF_ASSIGN_OR_RETURN(
    733           evaluated_[compare],
    734           Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
    735     } break;
    736     case BF16: {
    737       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    738                           Compare<bfloat16>(compare->shape(), direction,
    739                                             lhs_literal, rhs_literal));
    740     } break;
    741     case F32: {
    742       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    743                           Compare<float>(compare->shape(), direction,
    744                                          lhs_literal, rhs_literal));
    745     } break;
    746     case F64: {
    747       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    748                           Compare<double>(compare->shape(), direction,
    749                                           lhs_literal, rhs_literal));
    750     } break;
    751     case C64: {
    752       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    753                           Compare<complex64>(compare->shape(), direction,
    754                                              lhs_literal, rhs_literal));
    755     } break;
    756     case C128: {
    757       TF_ASSIGN_OR_RETURN(evaluated_[compare],
    758                           Compare<complex128>(compare->shape(), direction,
    759                                               lhs_literal, rhs_literal));
    760     } break;
    761     default:
    762       LOG(FATAL) << "HandleCompare: unknown primitive type: "
    763                  << PrimitiveType_Name(lhs->shape().element_type());
    764   }
    765 
    766   return Status::OK();
    767 }
    768 
    769 Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
    770   std::vector<const Literal*> operand_literals;
    771   for (auto operand : tuple->operands()) {
    772     operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
    773   }
    774 
    775   evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
    776   return Status::OK();
    777 }
    778 
    779 // Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
    780 // dimensions while keeping the rest of the output dimensions clamped to 0.
    781 ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
    782     const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
    783   int64 output_rank = output_shape.dimensions_size();
    784   std::vector<int64> index_base(output_rank, 0);
    785   std::vector<int64> index_count;
    786   index_count.reserve(output_rank);
    787   for (int64 i = 0; i < output_rank; i++) {
    788     bool is_output_batch_dim =
    789         !absl::c_binary_search(dim_numbers.offset_dims(), i);
    790     index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
    791   }
    792 
    793   return {std::move(index_base), std::move(index_count),
    794           std::vector<int64>(output_rank, 1)};
    795 }
    796 
    797 // Return an ShapeUtil::IndexIterationSpace that iterates over the output slice
    798 // dimensions while keeping the rest of the output dimensions clamped to 0.
    799 ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
    800     int64 output_rank, absl::Span<const int64> slice_sizes,
    801     const GatherDimensionNumbers& dim_numbers) {
    802   std::vector<int64> index_base(output_rank, 0);
    803   std::vector<int64> index_count(output_rank, 1);
    804   int64 slice_sizes_idx = 0;
    805   for (int64 i = 0; i < output_rank; i++) {
    806     bool is_output_window_dim =
    807         absl::c_binary_search(dim_numbers.offset_dims(), i);
    808     if (is_output_window_dim) {
    809       while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
    810                                    slice_sizes_idx)) {
    811         slice_sizes_idx++;
    812       }
    813       index_count[i] = slice_sizes[slice_sizes_idx++];
    814     }
    815   }
    816 
    817   return {std::move(index_base), std::move(index_count),
    818           std::vector<int64>(output_rank, 1)};
    819 }
    820 
    821 // This functor computes the contribution of start_indices to an input index
    822 // corresponding to an output index.  That is, given an output index I, it picks
    823 // out the batch indices in I and uses them to look up a starting index, G, from
    824 // the start indices tensor, and expands G into the input space according to
    825 // start_index_map.
    826 class OutputBatchIndexToInputIndex {
    827  public:
    828   // The constructor does some setup work that is amortized across all
    829   // iterations.
    830   explicit OutputBatchIndexToInputIndex(
    831       const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
    832       const Shape& output_shape, const Literal* start_indices)
    833       : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
    834     for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
    835       output_dim_is_batch_dims_.push_back(
    836           !absl::c_binary_search(dim_numbers_.offset_dims(), i));
    837     }
    838 
    839     for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
    840       int64 index_of_input_dim_in_index_vector =
    841           std::distance(dim_numbers_.start_index_map().begin(),
    842                         absl::c_find(dim_numbers_.start_index_map(), i));
    843       if (index_of_input_dim_in_index_vector ==
    844           dim_numbers_.start_index_map_size()) {
    845         input_dim_value_to_index_vector_.push_back(-1);
    846       } else {
    847         input_dim_value_to_index_vector_.push_back(
    848             index_of_input_dim_in_index_vector);
    849       }
    850     }
    851 
    852     index_vector_index_.resize(start_indices_.shape().dimensions_size());
    853     input_index_.resize(input_shape.dimensions_size());
    854     int64 index_vector_size =
    855         start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
    856     index_vector_.resize(index_vector_size);
    857   }
    858 
    859   // Returns the contribution of start_indices to the input index corresponding
    860   // to output_index.  See gather_inner_loop_body.
    861   //
    862   // This is conceptually  a stateless transformation from output_index to the
    863   // gather input index, but:
    864   //
    865   //  - Instead of allocating memory to represent the gather input index on
    866   //    every invocation we reuse the same storage for the result
    867   //    (input_index_), mutating it in place.
    868   //  - Instead of allocating buffers for temporary values like
    869   //    index_vector_index_ and index_vector on every invocation, we reuse the
    870   //    same storage for all invocations.
    871   //
    872   // This returns a Span into memory owned by the class.
    873   StatusOr<absl::Span<const int64>> operator()(
    874       absl::Span<const int64> output_index) {
    875     PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
    876     TF_RETURN_IF_ERROR(FetchIndexVector());
    877     PropagateIndexVectorToInputIndex();
    878     return absl::Span<const int64>(input_index_);
    879   }
    880 
    881  private:
    882   // Propagates the batch dimensions from the output index into
    883   // index_vector_index_ by mutating index_vector_index_ in place.  Does not
    884   // update the dim_numbers.index_vector_dim() dimension -- that's the dimension
    885   // we iterate over in FetchIndexVector.
    886   void PropagateOutputIndexGatherDimsToIndexVectorIndex(
    887       absl::Span<const int64> output_index) {
    888     int64 index_vector_index_i = 0;
    889     for (int64 i = 0, e = output_index.size(); i < e; i++) {
    890       if (!output_dim_is_batch_dims_[i]) {
    891         continue;
    892       }
    893 
    894       if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
    895         index_vector_index_i++;
    896       }
    897 
    898       index_vector_index_[index_vector_index_i++] = output_index[i];
    899     }
    900   }
    901 
    902   // Populates index_vector_ by iterating over start_indices_ according to
    903   // index_vector_index_.
    904   Status FetchIndexVector() {
    905     int64 index_vector_dim = dim_numbers_.index_vector_dim();
    906     for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
    907       index_vector_index_[index_vector_dim] = i;
    908       TF_ASSIGN_OR_RETURN(index_vector_[i],
    909                           start_indices_.GetIntegralAsS64(index_vector_index_));
    910     }
    911     return Status::OK();
    912   }
    913 
    914   // Populates input_index_.
    915   void PropagateIndexVectorToInputIndex() {
    916     for (int64 i = 0, e = input_index_.size(); i < e; i++) {
    917       if (input_dim_value_to_index_vector_[i] != -1) {
    918         input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
    919       }
    920 
    921       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
    922       // remains 0, as set by the constructor.
    923     }
    924   }
    925 
    926   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
    927   // the input index from the index vector.  See
    928   // PropagateIndexVectorToInputIndex.
    929   std::vector<int64> input_dim_value_to_index_vector_;
    930 
    931   // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
    932   // dimension.
    933   std::vector<bool> output_dim_is_batch_dims_;
    934 
    935   // The buffer into which we construct an index into start_indices_ to fetch
    936   // the index vector.
    937   std::vector<int64> index_vector_index_;
    938 
    939   // The index vector fetched from start_indices_.
    940   std::vector<int64> index_vector_;
    941 
    942   // The result computed by this functor.  operator() returns a Span into
    943   // this vector.
    944   std::vector<int64> input_index_;
    945 
    946   const GatherDimensionNumbers& dim_numbers_;
    947   const Literal& start_indices_;
    948 };
    949 
    950 // This functor computes the contribution of the offset indices in an output
    951 // index to an input index.  That is, given an output index I it picks out the
    952 // output offset indices in I and expands it into an index into the input shape.
    953 class OutputOffsetIndexToInputIndex {
    954  public:
    955   // The constructor does some setup work that is amortized across all
    956   // iterations.
    957   explicit OutputOffsetIndexToInputIndex(
    958       const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
    959       const Shape& output_shape) {
    960     std::vector<int64> window_index_to_output_index;
    961     int64 output_index_count = 0;
    962     for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
    963       if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
    964         window_index_to_output_index.push_back(output_index_count++);
    965       } else {
    966         output_index_count++;
    967       }
    968     }
    969 
    970     int64 window_dim_count = 0;
    971     for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
    972       if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
    973         input_dim_value_to_output_index_.push_back(-1);
    974       } else {
    975         input_dim_value_to_output_index_.push_back(
    976             window_index_to_output_index[window_dim_count++]);
    977       }
    978     }
    979 
    980     input_index_.resize(input_shape.dimensions_size());
    981   }
    982 
    983   // Returns the contribution of the window indices to the input index
    984   // corresponding to output_index.  See gather_inner_loop_body.
    985   //
    986   // This is conceptually a stateless transformation from output_index to the
    987   // window input index, but instead of allocating memory to represent the
    988   // gather input index on every invocation we reuse the same storage for the
    989   // result (input_index_), mutating it in place.
    990   //
    991   // This returns a Span into memory owned by the class.
    992   StatusOr<absl::Span<const int64>> operator()(
    993       absl::Span<const int64> output_index) {
    994     PropagateOutputIndexWindowDimsToInputIndex(output_index);
    995     return absl::Span<const int64>(input_index_);
    996   }
    997 
    998   // Returns for a given 'input_dim' the corresponding output dimension index,
    999   // or -1 if 'input_dim' is an elided window dimension.
   1000   int64 input_dim_value_to_output_index(int64 input_dim) {
   1001     return input_dim_value_to_output_index_[input_dim];
   1002   }
   1003 
   1004  private:
   1005   // Propagates window dimensions from the output index to input_index_ by
   1006   // mutating input_index_ in place.
   1007   void PropagateOutputIndexWindowDimsToInputIndex(
   1008       absl::Span<const int64> output_index) {
   1009     for (int64 i = 0, e = input_index_.size(); i < e; i++) {
   1010       if (input_dim_value_to_output_index_[i] != -1) {
   1011         input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
   1012       }
   1013 
   1014       // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
   1015       // remains 0, as set by the constructor.
   1016     }
   1017   }
   1018 
   1019   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
   1020   // the input index from the output index. See
   1021   // PropagateOutputIndexWindowDimsToInputIndex.
   1022   std::vector<int64> input_dim_value_to_output_index_;
   1023 
   1024   // The result computed by this functor.  operator() returns a Span into
   1025   // this vector.
   1026   std::vector<int64> input_index_;
   1027 };
   1028 
   1029 // Rehapes the gather indices input to have a trailing degenerate `1` dimension
   1030 // if necessary.  Hands over the ownership of the newly created literal (if
   1031 // there is one) to `reshaped_start_indices`.
   1032 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
   1033     int64 index_vector_dim, const Literal& start_indices,
   1034     Literal* reshaped_start_indices) {
   1035   if (start_indices.shape().dimensions_size() != index_vector_dim) {
   1036     return std::cref(start_indices);
   1037   }
   1038 
   1039   std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
   1040                                start_indices.shape().dimensions().end());
   1041   new_shape.push_back(1);
   1042   TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
   1043                       start_indices.Reshape(new_shape));
   1044   return std::cref(*reshaped_start_indices);
   1045 }
   1046 
   1047 Status HloEvaluator::HandleGather(HloInstruction* gather) {
   1048   Literal result = Literal::CreateFromShape(gather->shape());
   1049   const Shape& shape = gather->shape();
   1050   const GatherDimensionNumbers& dim_numbers =
   1051       gather->gather_dimension_numbers();
   1052   const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
   1053   Literal reshaped_start_indices;
   1054   TF_ASSIGN_OR_RETURN(
   1055       const Literal& start_indices,
   1056       ReshapedGatherIndices(dim_numbers.index_vector_dim(),
   1057                             GetEvaluatedLiteralFor(gather->operand(1)),
   1058                             &reshaped_start_indices));
   1059 
   1060   // We iterate over the gather dimensions in the output shape in an outer loop
   1061   // nest, and iterate over the window dimensions in the output shape in an
   1062   // inner loop nest.
   1063 
   1064   ShapeUtil::IndexIterationSpace start_indices_iteration_space =
   1065       IterationSpaceForOutputBatchIndices(shape, dim_numbers);
   1066   ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
   1067       IterationSpaceForOutputOffsetIndices(
   1068           shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
   1069 
   1070   // Scratch buffers that hold an index in the output shape and the
   1071   // corresponding index in the input shape.
   1072   std::vector<int64> input_index(operand.shape().dimensions_size());
   1073   std::vector<int64> output_index(gather->shape().dimensions_size());
   1074   std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
   1075 
   1076   OutputBatchIndexToInputIndex output_batch_index_to_input_index(
   1077       &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
   1078       /*output_shape=*/shape, &start_indices);
   1079   OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
   1080       gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
   1081       /*output_shape=*/shape);
   1082 
   1083   const Shape& operand_shape = operand.shape();
   1084 
   1085   auto gather_inner_loop_body =
   1086       [&](absl::Span<const int64> output_window_index,
   1087           absl::Span<const int64> input_gather_index,
   1088           absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
   1089     TF_ASSIGN_OR_RETURN(
   1090         absl::Span<const int64> input_window_index,
   1091         output_offset_index_to_input_index(output_window_index));
   1092     for (int i = 0, e = output_index.size(); i < e; i++) {
   1093       output_index[i] = output_gather_index[i] + output_window_index[i];
   1094       DCHECK_LT(output_index[i], shape.dimensions(i));
   1095     }
   1096     for (int i = 0, e = input_gather_index.size(); i < e; i++) {
   1097       int64 output_dim =
   1098           output_offset_index_to_input_index.input_dim_value_to_output_index(i);
   1099       // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
   1100       // we set the iteration index to 0, so for the purpose of the following
   1101       // calculations we can consider the output dimension size to be 1.
   1102       int64 output_dim_size =
   1103           output_dim == -1 ? 1 : shape.dimensions(output_dim);
   1104       // Clamp the gather index so that the gather region fits in the operand.
   1105       // input_index_clamped[i] = clamp(input_gather_index[i], 0,
   1106       //                                       operand_shape.dimensions(i) -
   1107       //                                       output_dim_size);
   1108       input_index_clamped[i] =
   1109           std::min(operand_shape.dimensions(i) - output_dim_size,
   1110                    std::max(0LL, input_gather_index[i]));
   1111     }
   1112     for (int i = 0, e = input_index.size(); i < e; i++) {
   1113       input_index[i] = input_index_clamped[i] + input_window_index[i];
   1114       DCHECK_GE(input_index[i], 0);
   1115       DCHECK_LT(input_index[i], operand_shape.dimensions(i));
   1116     }
   1117     TF_RETURN_IF_ERROR(
   1118         result.CopyElementFrom(operand, input_index, output_index));
   1119     return true;
   1120   };
   1121 
   1122   auto gather_outer_loop_body =
   1123       [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
   1124     TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
   1125                         output_batch_index_to_input_index(output_gather_index));
   1126     TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
   1127         shape, offset_indices_iteration_space,
   1128         std::bind(gather_inner_loop_body, std::placeholders::_1,
   1129                   input_gather_index, output_gather_index)));
   1130     return true;
   1131   };
   1132 
   1133   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
   1134       shape, start_indices_iteration_space, gather_outer_loop_body));
   1135   evaluated_[gather] = std::move(result);
   1136   return Status::OK();
   1137 }
   1138 
   1139 Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
   1140   const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
   1141 
   1142   TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
   1143       << "broadcast dimensions is of size: " << broadcast->dimensions().size()
   1144       << " and rank of operand_to_broadcast is: " << operand.shape().rank();
   1145   // Checks that operand's dimensions are the same as the broadcast's
   1146   // dimensions along the dimensions to be broadcasted.
   1147   for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
   1148     auto operand_dim_size = operand.shape().dimensions(i);
   1149     auto broadcast_dim_size =
   1150         broadcast->shape().dimensions(broadcast->dimensions(i));
   1151     TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
   1152         "Operand dimension %d is broadcast to output dimension %d, but the "
   1153         "sizes of these two dims do not match (%d vs %d): %s",
   1154         i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
   1155         broadcast->ToString());
   1156   }
   1157 
   1158   TF_ASSIGN_OR_RETURN(
   1159       evaluated_[broadcast],
   1160       operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
   1161 
   1162   return Status::OK();
   1163 }
   1164 
   1165 Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
   1166   evaluated_[after_all] = LiteralUtil::CreateToken();
   1167   return Status::OK();
   1168 }
   1169 
   1170 Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
   1171   // AddDedendency just forwards its zero-th operand.
   1172   evaluated_[add_dependency] =
   1173       GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
   1174   return Status::OK();
   1175 }
   1176 
   1177 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   1178   const auto result_shape = get_tuple_element->shape();
   1179   const int64 index = get_tuple_element->tuple_index();
   1180 
   1181   auto operand = get_tuple_element->operand(0);
   1182   TF_ASSIGN_OR_RETURN(
   1183       auto inferred_return_shape,
   1184       ShapeInference::InferGetTupleElementShape(operand->shape(), index));
   1185   TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
   1186       << "return shape set to: " << ShapeUtil::HumanString(result_shape)
   1187       << " but is inferred to be: "
   1188       << ShapeUtil::HumanString(inferred_return_shape);
   1189 
   1190   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
   1191 
   1192   evaluated_[get_tuple_element] =
   1193       Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
   1194   return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
   1195                                                 /*dest_shape_index=*/{},
   1196                                                 /*src_shape_index=*/{index});
   1197 }
   1198 
   1199 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
   1200   TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
   1201   evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
   1202   return Status::OK();
   1203 }
   1204 
   1205 Status HloEvaluator::HandleCall(HloInstruction* call) {
   1206   auto* computation = call->to_apply();
   1207   auto operands = call->operands();
   1208 
   1209   std::vector<const Literal*> arg_literals;
   1210   arg_literals.reserve(operands.size());
   1211   for (auto operand : operands) {
   1212     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
   1213     arg_literals.push_back(&arg_literal);
   1214   }
   1215 
   1216   HloEvaluator embedded_evaluator;
   1217   embedded_evaluator.set_dynamic_dimension_inference(
   1218       dynamic_dimension_inference_);
   1219   TF_ASSIGN_OR_RETURN(Literal result,
   1220                       embedded_evaluator.Evaluate(*computation, arg_literals));
   1221 
   1222   evaluated_[call] = std::move(result);
   1223   return Status::OK();
   1224 }
   1225 
   1226 Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
   1227   HloModuleConfig config;
   1228   // Attach cloned computation to an empty HLO module so the existing ones are
   1229   // not modified.
   1230   HloModule empty_hlo_module("EmptyModuleForFusion", config);
   1231   HloCloneContext context(&empty_hlo_module);
   1232   auto cloned_fused_computation =
   1233       fusion->fused_instructions_computation()->Clone(
   1234           /*suffix=*/"clone_with_layout", &context);
   1235   for (auto* instruction : cloned_fused_computation->instructions()) {
   1236     if (!LayoutUtil::HasLayout(instruction->shape())) {
   1237       LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
   1238     }
   1239   }
   1240   auto readded_computation =
   1241       empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));
   1242 
   1243   auto operands = fusion->operands();
   1244   std::vector<const Literal*> arg_literals;
   1245   arg_literals.reserve(operands.size());
   1246   for (auto operand : operands) {
   1247     const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
   1248     arg_literals.push_back(&arg_literal);
   1249   }
   1250 
   1251   HloEvaluator embedded_evaluator;
   1252   embedded_evaluator.set_dynamic_dimension_inference(
   1253       dynamic_dimension_inference_);
   1254   TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
   1255                                           *readded_computation, arg_literals));
   1256 
   1257   evaluated_[fusion] = std::move(result);
   1258   return Status::OK();
   1259 }
   1260 
   1261 Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
   1262   const auto& branch_index_literal =
   1263       GetEvaluatedLiteralFor(conditional->operand(0));
   1264   int branch_index;
   1265   if (conditional->operand(0)->shape().element_type() == PRED) {
   1266     branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
   1267   } else {
   1268     branch_index = branch_index_literal.Get<int32>({});
   1269     if (branch_index < 0 || branch_index >= conditional->branch_count()) {
   1270       branch_index = conditional->branch_count() - 1;
   1271     }
   1272   }
   1273   const auto& branch_computation_arg =
   1274       GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));
   1275 
   1276   HloEvaluator embedded_evaluator;
   1277   embedded_evaluator.set_dynamic_dimension_inference(
   1278       dynamic_dimension_inference_);
   1279   TF_ASSIGN_OR_RETURN(Literal result,
   1280                       embedded_evaluator.Evaluate(
   1281                           *conditional->branch_computation(branch_index),
   1282                           {&branch_computation_arg}));
   1283 
   1284   evaluated_[conditional] = std::move(result);
   1285   return Status::OK();
   1286 }
   1287 
   1288 Status HloEvaluator::HandleSelect(HloInstruction* select) {
   1289   const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
   1290   const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
   1291   const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
   1292 
   1293   // If predicate is of scalar type, no element-wise selection would be needed.
   1294   if (ShapeUtil::IsScalar(pred.shape())) {
   1295     if (pred.Get<bool>({})) {
   1296       evaluated_[select] = on_true.Clone();
   1297     } else {
   1298       evaluated_[select] = on_false.Clone();
   1299     }
   1300     return Status::OK();
   1301   }
   1302 
   1303   return DefaultAction(select);
   1304 }
   1305 
   1306 Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
   1307   const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
   1308   const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
   1309   const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
   1310 
   1311   if (pred.Get<bool>({})) {
   1312     evaluated_[tuple_select] = on_true.Clone();
   1313   } else {
   1314     evaluated_[tuple_select] = on_false.Clone();
   1315   }
   1316   return Status::OK();
   1317 }
   1318 
   1319 Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
   1320   HloComputation* cond_comp = while_hlo->while_condition();
   1321   HloComputation* body_comp = while_hlo->while_body();
   1322   // Initialize the loop carried valued with the input to the While instruction.
   1323   auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
   1324   bool keep_going = true;
   1325   int64 iteration_count = 0;
   1326   HloEvaluator cond_evaluator(max_loop_iterations_);
   1327   cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
   1328   HloEvaluator loop_body_evaluator(max_loop_iterations_);
   1329   loop_body_evaluator.set_dynamic_dimension_inference(
   1330       dynamic_dimension_inference_);
   1331   while (keep_going) {
   1332     if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
   1333       return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
   1334                              while_hlo->name(), max_loop_iterations_);
   1335     }
   1336     TF_ASSIGN_OR_RETURN(auto cond_val,
   1337                         cond_evaluator.Evaluate(*cond_comp, {&lcv}));
   1338     keep_going = cond_val.GetFirstElement<bool>();
   1339     if (keep_going) {
   1340       TF_ASSIGN_OR_RETURN(auto body_val,
   1341                           loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
   1342       VLOG(3) << "Loop iteration result: " << body_val.ToString();
   1343       lcv = std::move(body_val);
   1344       cond_evaluator.ResetVisitStates();
   1345       loop_body_evaluator.ResetVisitStates();
   1346     }
   1347   }
   1348   evaluated_[while_hlo] = std::move(lcv);
   1349   return Status::OK();
   1350 }
   1351 
   1352 namespace {
   1353 template <typename NativeT>
   1354 Literal ExtractLiteralFromIndexPositions(const Literal& from,
   1355                                          absl::Span<int64 const> indices,
   1356                                          bool extract_as_scalar) {
   1357   if (extract_as_scalar) {
   1358     return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
   1359   }
   1360   // We use a InlinedVector here because we need to convert it to an
   1361   // absl::Span later, and this would not work with std::vector<bool>.
   1362   absl::InlinedVector<NativeT, 10> values;
   1363   for (int64 index : indices) {
   1364     values.push_back(from.Get<NativeT>({index}));
   1365   }
   1366   return LiteralUtil::CreateR1<NativeT>(values);
   1367 }
   1368 
   1369 StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
   1370                                             absl::Span<int64 const> indices,
   1371                                             bool extract_as_scalar = false) {
   1372   if (extract_as_scalar) {
   1373     CHECK_EQ(indices.size(), 1);
   1374   }
   1375   PrimitiveType type = from.shape().element_type();
   1376   switch (type) {
   1377     case PRED: {
   1378       return ExtractLiteralFromIndexPositions<bool>(from, indices,
   1379                                                     extract_as_scalar);
   1380     }
   1381     case U8: {
   1382       return ExtractLiteralFromIndexPositions<uint8>(from, indices,
   1383                                                      extract_as_scalar);
   1384     }
   1385     case S8: {
   1386       return ExtractLiteralFromIndexPositions<int8>(from, indices,
   1387                                                     extract_as_scalar);
   1388     }
   1389     case BF16: {
   1390       return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
   1391                                                         extract_as_scalar);
   1392     }
   1393     case F16: {
   1394       return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
   1395                                                            extract_as_scalar);
   1396     }
   1397     case U16: {
   1398       return ExtractLiteralFromIndexPositions<uint16>(from, indices,
   1399                                                       extract_as_scalar);
   1400     }
   1401     case S16: {
   1402       return ExtractLiteralFromIndexPositions<int16>(from, indices,
   1403                                                      extract_as_scalar);
   1404     }
   1405     case F32: {
   1406       return ExtractLiteralFromIndexPositions<float>(from, indices,
   1407                                                      extract_as_scalar);
   1408     }
   1409     case U32: {
   1410       return ExtractLiteralFromIndexPositions<uint32>(from, indices,
   1411                                                       extract_as_scalar);
   1412     }
   1413     case S32: {
   1414       return ExtractLiteralFromIndexPositions<int32>(from, indices,
   1415                                                      extract_as_scalar);
   1416     }
   1417     case F64: {
   1418       return ExtractLiteralFromIndexPositions<double>(from, indices,
   1419                                                       extract_as_scalar);
   1420     }
   1421     case U64: {
   1422       return ExtractLiteralFromIndexPositions<uint64>(from, indices,
   1423                                                       extract_as_scalar);
   1424     }
   1425     case S64: {
   1426       return ExtractLiteralFromIndexPositions<int64>(from, indices,
   1427                                                      extract_as_scalar);
   1428     }
   1429     default:
   1430       return InvalidArgument("Unsupported type for Sort: %s",
   1431                              PrimitiveType_Name(type));
   1432   }
   1433 }
   1434 }  // namespace
   1435 
   1436 Status HloEvaluator::HandleSort(HloInstruction* sort) {
   1437   TF_RET_CHECK(sort->operand_count() >= 1)
   1438       << "Expected at least 1 operand for sort";
   1439   for (int64 i = 1; i < sort->operand_count(); ++i) {
   1440     TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
   1441                                            sort->operand(i)->shape()))
   1442         << "All Sort operands must have the same dimensions";
   1443   }
   1444 
   1445   if (VLOG_IS_ON(3)) {
   1446     for (int64 i = 0; i < sort->operand_count(); ++i) {
   1447       VLOG(3) << "HandleSort operand " << i << " literal: "
   1448               << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
   1449     }
   1450   }
   1451   Shape key_shape = sort->operand(0)->shape();
   1452   auto rank = key_shape.rank();
   1453   std::vector<Literal> result_literals;
   1454   result_literals.reserve(sort->operand_count());
   1455   for (int64 i = 0; i < sort->operand_count(); ++i) {
   1456     result_literals.emplace_back(sort->operand(i)->shape());
   1457   }
   1458   std::vector<int64> zero_base(rank, 0);
   1459   std::vector<int64> increment(rank, 1);
   1460   int64 sort_dim = sort->dimensions(0);
   1461   int64 sort_dim_elements = key_shape.dimensions(sort_dim);
   1462   increment[sort_dim] = sort_dim_elements;
   1463   HloEvaluator embedded_evaluator(max_loop_iterations_);
   1464   // Iterate through each dimension except 'sort_dim'.
   1465   TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
   1466       key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
   1467       [&](absl::Span<const int64> indices) -> StatusOr<bool> {
   1468         // Extract a slice from each operand literal that corresponds to
   1469         // exactly the row in dimension 'sort_dim'.
   1470         std::vector<int64> limit_indices(indices.begin(), indices.end());
   1471         absl::c_for_each(limit_indices, [](int64& index) { ++index; });
   1472         limit_indices[sort_dim] = sort_dim_elements;
   1473         std::vector<Literal> literals_to_sort;
   1474         literals_to_sort.reserve(sort->operand_count());
   1475         for (int64 i = 0; i < sort->operand_count(); ++i) {
   1476           TF_ASSIGN_OR_RETURN(auto literal_to_sort,
   1477                               GetEvaluatedLiteralFor(sort->operand(i))
   1478                                   .Slice(indices, limit_indices)
   1479                                   .Reshape({sort_dim_elements}));
   1480           literals_to_sort.push_back(std::move(literal_to_sort));
   1481         }
   1482         std::vector<int64> indices_to_sort(sort_dim_elements);
   1483         std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
   1484         Status compare_status = Status::OK();
   1485         auto comparator = [sort, &compare_status, &embedded_evaluator,
   1486                            &literals_to_sort](int64 a, int64 b) {
   1487           std::vector<Literal> literals;
   1488           literals.reserve(2 * sort->operand_count());
   1489           for (int64 i = 0; i < sort->operand_count(); ++i) {
   1490             auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
   1491                                                  /*extract_as_scalar=*/true);
   1492             if (!lhs.ok()) {
   1493               compare_status = lhs.status();
   1494               return false;
   1495             }
   1496             literals.push_back(std::move(lhs.ValueOrDie()));
   1497             auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
   1498                                                  /*extract_as_scalar=*/true);
   1499             if (!rhs.ok()) {
   1500               compare_status = rhs.status();
   1501               return false;
   1502             }
   1503             literals.push_back(std::move(rhs.ValueOrDie()));
   1504           }
   1505           std::vector<const Literal*> literal_ptrs;
   1506           absl::c_transform(literals, std::back_inserter(literal_ptrs),
   1507                             [](const Literal& literal) { return &literal; });
   1508 
   1509           auto computed_result =
   1510               embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
   1511           // Clear visit states so that we can use the evaluator again
   1512           // on the same computation.
   1513           embedded_evaluator.ResetVisitStates();
   1514           if (!computed_result.ok()) {
   1515             compare_status = computed_result.status();
   1516             return false;
   1517           }
   1518           return computed_result.ValueOrDie().Get<bool>({});
   1519         };
   1520         if (Cast<HloSortInstruction>(sort)->is_stable()) {
   1521           std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
   1522                            comparator);
   1523         } else {
   1524           std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
   1525         }
   1526         if (!compare_status.ok()) {
   1527           return compare_status;
   1528         }
   1529         std::vector<int64> slice_dimensions(rank, 1);
   1530         slice_dimensions[sort_dim] = sort_dim_elements;
   1531         std::vector<int64> start_indices(rank, 0);
   1532         for (int64 i = 0; i < sort->operand_count(); ++i) {
   1533           TF_ASSIGN_OR_RETURN(
   1534               Literal sorted_literal,
   1535               ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
   1536           TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
   1537                               sorted_literal.Reshape(slice_dimensions));
   1538           TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
   1539               sorted_literal_reshaped, start_indices, indices,
   1540               slice_dimensions));
   1541         }
   1542         return true;
   1543       }));
   1544 
   1545   if (sort->operand_count() == 1) {
   1546     evaluated_[sort] = std::move(result_literals[0]);
   1547   } else {
   1548     std::vector<const Literal*> literal_ptrs;
   1549     absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
   1550                       [](const Literal& literal) { return &literal; });
   1551 
   1552     Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
   1553     VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
   1554 
   1555     evaluated_[sort] = std::move(result_tuple);
   1556   }
   1557   return Status::OK();
   1558 }
   1559 
   1560 Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
   1561   if (!reduce->shape().IsTuple()) {
   1562     return DefaultAction(reduce);
   1563   } else {
   1564     auto first_element_type = reduce->shape().tuple_shapes(0).element_type();
   1565     for (const auto& tuple_shape : reduce->shape().tuple_shapes()) {
   1566       if (tuple_shape.element_type() != first_element_type) {
   1567         return Unimplemented(
   1568             "Reduce with several outputs that have mixed element types is "
   1569             "unsupported");
   1570       }
   1571     }
   1572     return reduce->Visit(typed_visitors_[first_element_type].get());
   1573   }
   1574 }
   1575 
   1576 Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
   1577   if (!custom_call_handler_) {
   1578     // No handler is registered; this means custom-calls are not allowed.
   1579     return DefaultAction(custom_call);
   1580   }
   1581 
   1582   // Evaluate input operands so the handler has access to the operand data.
   1583   std::vector<const Literal*> operands;
   1584   operands.reserve(custom_call->operand_count());
   1585   for (const HloInstruction* operand : custom_call->operands()) {
   1586     operands.push_back(&GetEvaluatedLiteralFor(operand));
   1587   }
   1588 
   1589   // Synchronously issue the handler to populate the instruction output literal.
   1590   TF_ASSIGN_OR_RETURN(
   1591       auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));
   1592 
   1593   evaluated_[custom_call] = std::move(output);
   1594   return Status::OK();
   1595 }
   1596 
   1597 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
   1598   VLOG(2) << "About to visit HLO: " << hlo->ToString();
   1599   return ShapeUtil::ValidateShape(hlo->shape());
   1600 }
   1601 
   1602 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
   1603   VLOG(2) << "Finished visiting " << hlo->ToString()
   1604           << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
   1605   // Out of convenience the literal may have been produced with a different
   1606   // layout. Relayout as indicated by the HLO instruction.
   1607   if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
   1608                                         hlo->shape())) {
   1609     evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
   1610   }
   1611   return Status::OK();
   1612 }
   1613 
   1614 namespace {
   1615 template <typename T>
   1616 std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
   1617     const Array2D<T>& lhs, const Array2D<T>& rhs,
   1618     const std::function<void(
   1619         const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n,
   1620         int64 k, int32 transpose_lhs, int32 transpose_rhs)>& impl_fn) {
   1621   CHECK_EQ(lhs.width(), rhs.height());
   1622   int m = lhs.height();
   1623   int n = rhs.width();
   1624   int k = lhs.width();
   1625   auto result = absl::make_unique<Array2D<T>>(m, n);
   1626   // Because Eigen is a header-oriented library, make sure that the Eigen code
   1627   // is the same as the code used by the CPU backend (otherwise the linker will
   1628   // randomly pick *some* definition).
   1629   impl_fn(
   1630       /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
   1631       k,
   1632       /*transpose_lhs=*/0,
   1633       /*transpose_rhs=*/0);
   1634   return result;
   1635 }
   1636 }  // namespace
   1637 
   1638 std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
   1639     const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
   1640   return MatmulArray2DImpl<Eigen::half>(
   1641       lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
   1642 }
   1643 
   1644 std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
   1645     const Array2D<float>& lhs, const Array2D<float>& rhs) {
   1646   return MatmulArray2DImpl<float>(
   1647       lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
   1648 }
   1649 
   1650 std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
   1651     const Array2D<double>& lhs, const Array2D<double>& rhs) {
   1652   return MatmulArray2DImpl<double>(
   1653       lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
   1654 }
   1655 
   1656 }  // namespace xla
   1657