/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {
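// Evaluates an elementwise comparison between two literals of the same shape
// and element type, producing a PRED literal of the same dimensions. As an
// illustrative example (not code from this file): comparing lhs = [1, 5, 3]
// against rhs = [2, 5, 1] with direction kGe yields [false, true, true].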
template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
                          LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
  std::function<bool(OperandT, OperandT)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    case ComparisonDirection::kGe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el >= rhs_el;
      };
      break;
    case ComparisonDirection::kGt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el > rhs_el;
      };
      break;
    case ComparisonDirection::kLe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el <= rhs_el;
      };
      break;
    case ComparisonDirection::kLt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el < rhs_el;
      };
      break;
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<OperandT>(multi_index),
                          rhs_literal.Get<OperandT>(multi_index));
      }));

  return std::move(result);
}
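// Complex element types support only equality comparisons: std::complex has
// no ordering, so the specializations below handle kEq and kNe and treat any
// other direction as a fatal error.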
template <>
StatusOr<Literal> Compare<complex64>(const Shape& shape,
                                     ComparisonDirection direction,
                                     LiteralSlice lhs_literal,
                                     LiteralSlice rhs_literal) {
  std::function<bool(complex64, complex64)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled direction for conversion to Comparison: "
                 << ComparisonDirectionToString(direction);
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<complex64>(multi_index),
                          rhs_literal.Get<complex64>(multi_index));
      }));

  return std::move(result);
}

template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape,
                                      ComparisonDirection direction,
                                      LiteralSlice lhs_literal,
                                      LiteralSlice rhs_literal) {
  std::function<bool(complex128, complex128)> compare_op;
  switch (direction) {
    case ComparisonDirection::kEq:
      compare_op = [](complex128 lhs_el, complex128 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case ComparisonDirection::kNe:
      compare_op = [](complex128 lhs_el, complex128 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled direction for conversion to Comparison: "
                 << ComparisonDirectionToString(direction);
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        return compare_op(lhs_literal.Get<complex128>(multi_index),
                          rhs_literal.Get<complex128>(multi_index));
      }));

  return std::move(result);
}

}  // namespace

// Note that a type unsupported by the typed visitor does not necessarily
// imply that the non-typed HloEvaluator (the parent evaluator) cannot support
// it in its type-agnostic handlers. For example, HandleGetTupleElement in the
// parent type-agnostic evaluator can accept the TUPLE primitive type, whereas
// HloEvaluatorTypedVisitor cannot.
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
    : max_loop_iterations_(max_loop_iterations) {
  typed_visitors_[PRED] =
      absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
  typed_visitors_[U8] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
  typed_visitors_[U16] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
  typed_visitors_[U32] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
  typed_visitors_[U64] =
      absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
  typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
  typed_visitors_[S16] =
      absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
  typed_visitors_[S32] =
      absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
  typed_visitors_[S64] =
      absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
  typed_visitors_[F16] =
      absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
  typed_visitors_[F32] =
      absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
  typed_visitors_[F64] =
      absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
  typed_visitors_[C64] =
      absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
  typed_visitors_[C128] =
      absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);

  // Most of the evaluator computations we use don't support BF16 (e.g.,
  // std::ceil, std::tanh). To make the evaluator work with BF16, we set all
  // elementwise computations to be done in F32 and do BF16<->F32 conversion
  // around the input and the output of the computations.
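  // For instance, under this scheme tanh on a bfloat16 element x is
  // effectively evaluated as bfloat16(std::tanh(static_cast<float>(x))):
  // the operand is widened to float, the op is computed in float, and the
  // result is rounded back to bfloat16 when written to the output literal.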
  typed_visitors_[BF16] =
      absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);

  typed_visitors_[TUPLE] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
      });
  typed_visitors_[OPAQUE] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
      });
  typed_visitors_[TOKEN] =
      absl::make_unique<FunctionVisitor>([](HloInstruction*) {
        return Unimplemented(
            "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN.");
      });
}

StatusOr<Literal> HloEvaluator::Evaluate(
    const HloComputation& computation,
    absl::Span<const Literal* const> arg_literals) {
  CHECK(computation.parent() != nullptr);
  XLA_VLOG_LINES(
      2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());

  if (arg_literals.size() != computation.num_parameters()) {
    return InvalidArgument(
        "Expected %d argument%s, but got %d.", computation.num_parameters(),
        computation.num_parameters() == 1 ? "" : "s", arg_literals.size());
  }
  for (int64 i = 0; i < arg_literals.size(); ++i) {
    const auto& computation_shape =
        computation.parameter_instruction(i)->shape();
    const auto& arg_shape = arg_literals[i]->shape();
    if (!ShapeUtil::Equal(computation_shape, arg_shape)) {
      return InvalidArgument(
          "Shape mismatch at parameter %d. Computation expected %s, but arg "
          "was %s.",
          i, ShapeUtil::HumanStringWithLayout(computation_shape),
          ShapeUtil::HumanString(arg_shape));
    }
  }

  evaluated_.clear();
  arg_literals_.clear();
  for (const auto& literal_ptr : arg_literals) {
    arg_literals_.push_back(&*literal_ptr);
  }

  // Re-seed the RNG, either from the configuration's seed or from a monotonic
  // per-evaluator seed (which prevents two evaluators from returning the same
  // random sequence).
  if (computation.parent()->config().seed()) {
    seed_ = computation.parent()->config().seed();
  } else {
    // Start global_seed at a (true) random value.
    static std::atomic<uint64> global_seed{std::random_device()()};
    seed_ = global_seed.fetch_add(1);
  }
  engine_.seed(seed_);

  TF_RETURN_IF_ERROR(computation.Accept(this));
  return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
}

StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kParameter) {
    return tensorflow::errors::FailedPrecondition(
        "Cannot evaluate a parameter.");
  }
  if (!hlo_query::AllOperandsAreConstants(*instruction)) {
    return tensorflow::errors::FailedPrecondition(
        "Not all operands are constants.");
  }

  arg_literals_.clear();
  evaluated_.clear();

  TF_RETURN_IF_ERROR(Preprocess(instruction));
  TF_RETURN_IF_ERROR(instruction->Visit(this));
  TF_RETURN_IF_ERROR(Postprocess(instruction));
  return GetEvaluatedLiteralFor(instruction).Clone();
}

bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
  CHECK(result != nullptr);
  auto result_or = Evaluate(instruction);
  if (!result_or.ok()) {
    VLOG(1) << "TryEvaluate failed: " << result_or.status();
    return false;
  }

  *result = result_or.ConsumeValueOrDie();
  return true;
}
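// Evaluates `instruction` with some operands replaced by the provided
// constant literals. A hypothetical usage sketch (the instruction names are
// invented for illustration):
//
//   // add = Add(x, c), where c is a constant.
//   Literal two = LiteralUtil::CreateR0<int32>(2);
//   evaluator.EvaluateWithSubstitutions(add, {{x, &two}});
//
// Operands not mentioned in `substitutions` are cloned as-is, so they must
// already be constants for the evaluation to succeed.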
StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
    const HloInstruction* instruction,
    const std::unordered_map<const HloInstruction*, const Literal*>&
        substitutions) {
  std::vector<std::unique_ptr<HloInstruction>> owned_operands;
  for (const HloInstruction* operand : instruction->operands()) {
    auto it = substitutions.find(operand);
    if (it == substitutions.end()) {
      owned_operands.push_back(operand->Clone());
    } else {
      owned_operands.push_back(
          HloInstruction::CreateConstant(it->second->Clone()));
    }
  }

  std::vector<HloInstruction*> operands;
  operands.reserve(owned_operands.size());
  for (auto& operand : owned_operands) {
    operands.push_back(operand.get());
  }

  std::unique_ptr<HloInstruction> cloned_instruction =
      instruction->CloneWithNewOperands(instruction->shape(), operands);
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
    HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
                                   rhs_instr.get());
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
    HloOpcode opcode, const Literal& operand) {
  std::unique_ptr<HloInstruction> operand_instr =
      HloInstruction::CreateConstant(operand.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
  auto result = Evaluate(cloned_instruction.get());

  return result;
}

StatusOr<Literal> HloEvaluator::EvaluateDotOp(
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, const Literal& lhs,
    const Literal& rhs) {
  std::unique_ptr<HloInstruction> lhs_instr =
      HloInstruction::CreateConstant(lhs.Clone());
  std::unique_ptr<HloInstruction> rhs_instr =
      HloInstruction::CreateConstant(rhs.Clone());

  TF_ASSIGN_OR_RETURN(
      Shape dot_shape,
      ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
                                dim_numbers, precision_config);
  return Evaluate(cloned_instruction.get());
}

Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) {
  const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0));
  Literal result(bitcast->shape());
  TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes());
  memcpy(result.untyped_data(), operand_literal.untyped_data(),
         operand_literal.size_bytes());
  evaluated_[bitcast] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleGetDimensionSize(
    HloInstruction* get_dimension_size) {
  HloInstruction* operand = get_dimension_size->mutable_operand(0);
  int64 dim = get_dimension_size->dimension();
  if (dynamic_dimension_inference_ == nullptr) {
    return InvalidArgument(
        "Evaluator cannot evaluate get_dimension_size without "
        "set_dynamic_dimension_inference.");
  }
  HloInstruction* dynamic_size =
      dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim);
  if (dynamic_size != nullptr) {
    evaluated_[get_dimension_size] =
        GetEvaluatedLiteralFor(dynamic_size).Clone();
    return Status::OK();
  }

  const Shape& shape = get_dimension_size->operand(0)->shape();
  Literal output(ShapeUtil::MakeShape(U32, {}));
  output.PopulateWithValue(
      static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
  evaluated_[get_dimension_size] = std::move(output);
  return Status::OK();
}

Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
  // Nothing to do other than sanity checks. Parameters' values are stored in
  // arg_literals_.
  CHECK_LT(parameter->parameter_number(), arg_literals_.size());

#ifndef NDEBUG
  const Literal* input_literal = arg_literals_[parameter->parameter_number()];
  VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
  DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
      << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
      << ", but input literal shape is: "
      << ShapeUtil::HumanString(input_literal->shape());
#endif

  return Status::OK();
}

Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }

Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
  TF_ASSIGN_OR_RETURN(
      evaluated_[reshape],
      GetEvaluatedLiteralFor(reshape->operand(0))
          .Reshape(AsInt64Slice(reshape->shape().dimensions())));
  return Status::OK();
}

Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
  evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
                              .Transpose(transpose->dimensions());
  return Status::OK();
}
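// The size of the output along the concatenation dimension is the sum of the
// operands' sizes along that dimension. For example (illustrative),
// concatenating f32[2,3] and f32[2,5] along dimension 1 yields f32[2,8].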
Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  // The result concatenate dimension is going to be the sum of all
  // concatenate dimensions of the operands taking part in the operation.
  const Shape& reference_shape = operands[0]->shape();
  CHECK(reference_shape.IsArray());
  const int64 rank = reference_shape.rank();
  const int64 concat_dim = concatenate->dimensions()[0];
  CHECK_GE(concat_dim, 0);
  CHECK_LT(concat_dim, rank);

  DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
                                    reference_shape.dimensions().end());

  for (int64 i = 1; i < operands.size(); ++i) {
    const Shape& operand_shape = operands[i]->shape();
    CHECK(operand_shape.IsArray());
    // Accumulate the concat dimension from all tensors taking part in the
    // operation.
    concat_dimensions[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  auto result_literal = LiteralUtil::CreateFromDimensions(
      reference_shape.element_type(), concat_dimensions);
  DimensionVector source_indices(rank, 0);
  DimensionVector dest_indices(concat_dimensions.size(), 0);

  for (auto operand : operands) {
    const Shape& operand_shape = operand->shape();
    TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
        GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
        AsInt64Slice(operand_shape.dimensions())));
    dest_indices[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  evaluated_[concatenate] = std::move(result_literal);
  return Status::OK();
}

Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
  auto operand = is_finite->operand(0);
  auto elem_ty = operand->shape().element_type();
  switch (elem_ty) {
    case PRED:
    case TUPLE:
    case OPAQUE:
    case TOKEN:
    case S8:
    case S16:
    case S32:
    case S64:
    case U8:
    case U16:
    case U32:
    case U64:
    case C64:
    case C128:
    // Explicitly enumerate all types in this switch so that when we add a new
    // type, we'll get a compile error here.
    case PRIMITIVE_TYPE_INVALID:
    case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
    case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
      return InvalidArgument(
          "expected element type in shape to be floating point, but "
          "got: %s",
          PrimitiveType_Name(elem_ty));

    case F16: {
      auto result_or = ElementWiseUnaryOpImpl<bool, Eigen::half>(
          is_finite,
          [](Eigen::half elem_operand) {
            return std::isfinite(static_cast<float>(elem_operand));
          },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case BF16: {
      auto result_or = ElementWiseUnaryOpImpl<bool, bfloat16>(
          is_finite,
          [](bfloat16 elem_operand) {
            return std::isfinite(static_cast<float>(elem_operand));
          },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case F32: {
      auto result_or = ElementWiseUnaryOpImpl<bool, float>(
          is_finite,
          [](float elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case F64: {
      auto result_or = ElementWiseUnaryOpImpl<bool, double>(
          is_finite,
          [](double elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
  }

  return Status::OK();
}

Status HloEvaluator::HandleReal(HloInstruction* real) {
  auto operand = real->operand(0);
  switch (operand->shape().element_type()) {
    case BF16: {
      auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
          real, [](bfloat16 elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case C64: {
      auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
          real, [](complex64 elem_operand) { return std::real(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case C128: {
      auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
          real, [](complex128 elem_operand) { return std::real(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F16: {
      auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
          real, [](Eigen::half elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F32: {
      auto result_or = ElementWiseUnaryOpImpl<float, float>(
          real, [](float elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    case F64: {
      auto result_or = ElementWiseUnaryOpImpl<double, double>(
          real, [](double elem_operand) { return elem_operand; },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
      break;
    }
    default:
      LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(operand->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleImag(HloInstruction* imag) {
  auto operand = imag->operand(0);
  switch (operand->shape().element_type()) {
    case C64: {
      auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
          imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
          GetEvaluatedLiteralFor(operand));

      TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
      break;
    }
    case C128: {
      auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
          imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
          GetEvaluatedLiteralFor(operand));

      TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
      break;
    }
    default:
      LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(operand->shape().element_type());
  }

  return Status::OK();
}
Status HloEvaluator::HandleComplex(HloInstruction* complex) {
  const Literal& real = GetEvaluatedLiteralFor(complex->operand(0));
  const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1));
  TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));

  Literal result(complex->shape());
  switch (complex->shape().element_type()) {
    case C64: {
      TF_RETURN_IF_ERROR(
          result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
            return std::complex<float>(real.Get<float>(multi_index),
                                       imag.Get<float>(multi_index));
          }));
      break;
    }
    case C128: {
      TF_RETURN_IF_ERROR(
          result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
            return std::complex<double>(real.Get<double>(multi_index),
                                        imag.Get<double>(multi_index));
          }));
      break;
    }
    default:
      LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
                 << PrimitiveType_Name(complex->shape().element_type());
  }

  evaluated_[complex] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleCompare(HloInstruction* compare) {
  ComparisonDirection direction = compare->comparison_direction();
  auto lhs = compare->operand(0);
  auto rhs = compare->operand(1);
  DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
         ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()));

  TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());

  const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
  const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);

  // Note here we switch on the operand's type.
  switch (lhs->shape().element_type()) {
    case PRED: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case U8: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint8>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case U16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint16>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case U32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint32>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case U64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<uint64>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case S8: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case S16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int16>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case S32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int32>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case S64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<int64>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case F16: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
    } break;
    case BF16: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<bfloat16>(compare->shape(), direction,
                                            lhs_literal, rhs_literal));
    } break;
    case F32: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<float>(compare->shape(), direction,
                                         lhs_literal, rhs_literal));
    } break;
    case F64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<double>(compare->shape(), direction,
                                          lhs_literal, rhs_literal));
    } break;
    case C64: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<complex64>(compare->shape(), direction,
                                             lhs_literal, rhs_literal));
    } break;
    case C128: {
      TF_ASSIGN_OR_RETURN(evaluated_[compare],
                          Compare<complex128>(compare->shape(), direction,
                                              lhs_literal, rhs_literal));
    } break;
    default:
      LOG(FATAL) << "HandleCompare: unknown primitive type: "
                 << PrimitiveType_Name(lhs->shape().element_type());
  }

  return Status::OK();
}

Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
  std::vector<const Literal*> operand_literals;
  for (auto operand : tuple->operands()) {
    operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
  }

  evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
  return Status::OK();
}
// Returns a ShapeUtil::IndexIterationSpace that iterates over the output
// batch dimensions while keeping the rest of the output dimensions clamped
// to 0.
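// For example (illustrative): for an output shape f32[3,4,5] with
// offset_dims = {1}, dimensions 0 and 2 are batch dimensions, so the
// iteration space is index_base = {0,0,0}, index_count = {3,1,5}, with a
// stride of 1 in every dimension.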
ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
    const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
  int64 output_rank = output_shape.dimensions_size();
  std::vector<int64> index_base(output_rank, 0);
  std::vector<int64> index_count;
  index_count.reserve(output_rank);
  for (int64 i = 0; i < output_rank; i++) {
    bool is_output_batch_dim =
        !absl::c_binary_search(dim_numbers.offset_dims(), i);
    index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
  }

  return {std::move(index_base), std::move(index_count),
          std::vector<int64>(output_rank, 1)};
}

// Returns a ShapeUtil::IndexIterationSpace that iterates over the output
// slice dimensions while keeping the rest of the output dimensions clamped
// to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
    int64 output_rank, absl::Span<const int64> slice_sizes,
    const GatherDimensionNumbers& dim_numbers) {
  std::vector<int64> index_base(output_rank, 0);
  std::vector<int64> index_count(output_rank, 1);
  int64 slice_sizes_idx = 0;
  for (int64 i = 0; i < output_rank; i++) {
    bool is_output_window_dim =
        absl::c_binary_search(dim_numbers.offset_dims(), i);
    if (is_output_window_dim) {
      while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
                                   slice_sizes_idx)) {
        slice_sizes_idx++;
      }
      index_count[i] = slice_sizes[slice_sizes_idx++];
    }
  }

  return {std::move(index_base), std::move(index_count),
          std::vector<int64>(output_rank, 1)};
}

// This functor computes the contribution of start_indices to an input index
// corresponding to an output index. That is, given an output index I, it
// picks out the batch indices in I and uses them to look up a starting
// index, G, from the start indices tensor, and expands G into the input
// space according to start_index_map.
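// As an illustrative example: with start_indices of shape s32[5,2],
// index_vector_dim = 1 and start_index_map = {0, 2}, the batch components of
// the output index select a row i of start_indices; the fetched index vector
// G = [g0, g1] is then expanded into the input space as [g0, 0, g1, 0, ...],
// with the remaining input dimensions left at 0.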
class OutputBatchIndexToInputIndex {
 public:
  // The constructor does some setup work that is amortized across all
  // iterations.
  explicit OutputBatchIndexToInputIndex(
      const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
      const Shape& output_shape, const Literal* start_indices)
      : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
    for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
      output_dim_is_batch_dims_.push_back(
          !absl::c_binary_search(dim_numbers_.offset_dims(), i));
    }

    for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
      int64 index_of_input_dim_in_index_vector =
          std::distance(dim_numbers_.start_index_map().begin(),
                        absl::c_find(dim_numbers_.start_index_map(), i));
      if (index_of_input_dim_in_index_vector ==
          dim_numbers_.start_index_map_size()) {
        input_dim_value_to_index_vector_.push_back(-1);
      } else {
        input_dim_value_to_index_vector_.push_back(
            index_of_input_dim_in_index_vector);
      }
    }

    index_vector_index_.resize(start_indices_.shape().dimensions_size());
    input_index_.resize(input_shape.dimensions_size());
    int64 index_vector_size =
        start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
    index_vector_.resize(index_vector_size);
  }

  // Returns the contribution of start_indices to the input index
  // corresponding to output_index. See gather_inner_loop_body.
  //
  // This is conceptually a stateless transformation from output_index to the
  // gather input index, but:
  //
  //  - Instead of allocating memory to represent the gather input index on
  //    every invocation we reuse the same storage for the result
  //    (input_index_), mutating it in place.
  //  - Instead of allocating buffers for temporary values like
  //    index_vector_index_ and index_vector_ on every invocation, we reuse
  //    the same storage for all invocations.
  //
  // This returns a Span into memory owned by the class.
  StatusOr<absl::Span<const int64>> operator()(
      absl::Span<const int64> output_index) {
    PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
    TF_RETURN_IF_ERROR(FetchIndexVector());
    PropagateIndexVectorToInputIndex();
    return absl::Span<const int64>(input_index_);
  }

 private:
  // Propagates the batch dimensions from the output index into
  // index_vector_index_ by mutating index_vector_index_ in place. Does not
  // update the dim_numbers_.index_vector_dim() dimension -- that's the
  // dimension we iterate over in FetchIndexVector.
  void PropagateOutputIndexGatherDimsToIndexVectorIndex(
      absl::Span<const int64> output_index) {
    int64 index_vector_index_i = 0;
    for (int64 i = 0, e = output_index.size(); i < e; i++) {
      if (!output_dim_is_batch_dims_[i]) {
        continue;
      }

      if (index_vector_index_i == dim_numbers_.index_vector_dim()) {
        index_vector_index_i++;
      }

      index_vector_index_[index_vector_index_i++] = output_index[i];
    }
  }

  // Populates index_vector_ by iterating over start_indices_ according to
  // index_vector_index_.
  Status FetchIndexVector() {
    int64 index_vector_dim = dim_numbers_.index_vector_dim();
    for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
      index_vector_index_[index_vector_dim] = i;
      TF_ASSIGN_OR_RETURN(index_vector_[i],
                          start_indices_.GetIntegralAsS64(index_vector_index_));
    }
    return Status::OK();
  }

  // Populates input_index_.
  void PropagateIndexVectorToInputIndex() {
    for (int64 i = 0, e = input_index_.size(); i < e; i++) {
      if (input_dim_value_to_index_vector_[i] != -1) {
        input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]];
      }

      // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i]
      // remains 0, as set by the constructor.
    }
  }

  // input_dim_value_to_index_vector_[i] tells us how to compute dimension i
  // of the input index from the index vector. See
  // PropagateIndexVectorToInputIndex.
  std::vector<int64> input_dim_value_to_index_vector_;

  // output_dim_is_batch_dims_[i] is true iff the output index i is a gather
  // dimension.
  std::vector<bool> output_dim_is_batch_dims_;

  // The buffer into which we construct an index into start_indices_ to fetch
  // the index vector.
  std::vector<int64> index_vector_index_;

  // The index vector fetched from start_indices_.
  std::vector<int64> index_vector_;

  // The result computed by this functor. operator() returns a Span into
  // this vector.
  std::vector<int64> input_index_;

  const GatherDimensionNumbers& dim_numbers_;
  const Literal& start_indices_;
};
// This functor computes the contribution of the offset indices in an output
// index to an input index. That is, given an output index I it picks out the
// output offset indices in I and expands them into an index into the input
// shape.
class OutputOffsetIndexToInputIndex {
 public:
  // The constructor does some setup work that is amortized across all
  // iterations.
  explicit OutputOffsetIndexToInputIndex(
      const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
      const Shape& output_shape) {
    std::vector<int64> window_index_to_output_index;
    int64 output_index_count = 0;
    for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
      if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
        window_index_to_output_index.push_back(output_index_count++);
      } else {
        output_index_count++;
      }
    }

    int64 window_dim_count = 0;
    for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
      if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
        input_dim_value_to_output_index_.push_back(-1);
      } else {
        input_dim_value_to_output_index_.push_back(
            window_index_to_output_index[window_dim_count++]);
      }
    }

    input_index_.resize(input_shape.dimensions_size());
  }

  // Returns the contribution of the window indices to the input index
  // corresponding to output_index. See gather_inner_loop_body.
  //
  // This is conceptually a stateless transformation from output_index to the
  // window input index, but instead of allocating memory to represent the
  // gather input index on every invocation we reuse the same storage for the
  // result (input_index_), mutating it in place.
  //
  // This returns a Span into memory owned by the class.
  StatusOr<absl::Span<const int64>> operator()(
      absl::Span<const int64> output_index) {
    PropagateOutputIndexWindowDimsToInputIndex(output_index);
    return absl::Span<const int64>(input_index_);
  }

  // Returns for a given 'input_dim' the corresponding output dimension index,
  // or -1 if 'input_dim' is an elided window dimension.
  int64 input_dim_value_to_output_index(int64 input_dim) {
    return input_dim_value_to_output_index_[input_dim];
  }

 private:
  // Propagates window dimensions from the output index to input_index_ by
  // mutating input_index_ in place.
  void PropagateOutputIndexWindowDimsToInputIndex(
      absl::Span<const int64> output_index) {
    for (int64 i = 0, e = input_index_.size(); i < e; i++) {
      if (input_dim_value_to_output_index_[i] != -1) {
        input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
      }

      // If input_dim_value_to_output_index_[i] == -1 then input_index_[i]
      // remains 0, as set by the constructor.
    }
  }

  // input_dim_value_to_output_index_[i] tells us how to compute dimension i
  // of the input index from the output index. See
  // PropagateOutputIndexWindowDimsToInputIndex.
  std::vector<int64> input_dim_value_to_output_index_;

  // The result computed by this functor. operator() returns a Span into
  // this vector.
  std::vector<int64> input_index_;
};

// Reshapes the gather indices input to have a trailing degenerate `1`
// dimension if necessary. Hands over the ownership of the newly created
// literal (if there is one) to `reshaped_start_indices`.
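// For example (illustrative): start_indices of shape s32[4] with
// index_vector_dim == 1 (equal to its rank) is reshaped to s32[4,1], so that
// each (scalar) index vector occupies the trailing dimension.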
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
    int64 index_vector_dim, const Literal& start_indices,
    Literal* reshaped_start_indices) {
  if (start_indices.shape().dimensions_size() != index_vector_dim) {
    return std::cref(start_indices);
  }

  std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
                               start_indices.shape().dimensions().end());
  new_shape.push_back(1);
  TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
                      start_indices.Reshape(new_shape));
  return std::cref(*reshaped_start_indices);
}

Status HloEvaluator::HandleGather(HloInstruction* gather) {
  Literal result = Literal::CreateFromShape(gather->shape());
  const Shape& shape = gather->shape();
  const GatherDimensionNumbers& dim_numbers =
      gather->gather_dimension_numbers();
  const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
  Literal reshaped_start_indices;
  TF_ASSIGN_OR_RETURN(
      const Literal& start_indices,
      ReshapedGatherIndices(dim_numbers.index_vector_dim(),
                            GetEvaluatedLiteralFor(gather->operand(1)),
                            &reshaped_start_indices));

  // We iterate over the gather dimensions in the output shape in an outer
  // loop nest, and iterate over the window dimensions in the output shape in
  // an inner loop nest.

  ShapeUtil::IndexIterationSpace start_indices_iteration_space =
      IterationSpaceForOutputBatchIndices(shape, dim_numbers);
  ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
      IterationSpaceForOutputOffsetIndices(
          shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);

  // Scratch buffers that hold an index in the output shape and the
  // corresponding index in the input shape.
  std::vector<int64> input_index(operand.shape().dimensions_size());
  std::vector<int64> output_index(gather->shape().dimensions_size());
  std::vector<int64> input_index_clamped(operand.shape().dimensions_size());

  OutputBatchIndexToInputIndex output_batch_index_to_input_index(
      &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
      /*output_shape=*/shape, &start_indices);
  OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
      gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
      /*output_shape=*/shape);

  const Shape& operand_shape = operand.shape();
  auto gather_inner_loop_body =
      [&](absl::Span<const int64> output_window_index,
          absl::Span<const int64> input_gather_index,
          absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(
        absl::Span<const int64> input_window_index,
        output_offset_index_to_input_index(output_window_index));
    for (int i = 0, e = output_index.size(); i < e; i++) {
      output_index[i] = output_gather_index[i] + output_window_index[i];
      DCHECK_LT(output_index[i], shape.dimensions(i));
    }
    for (int i = 0, e = input_gather_index.size(); i < e; i++) {
      int64 output_dim =
          output_offset_index_to_input_index.input_dim_value_to_output_index(i);
      // If 'output_dim' is -1, it means 'i' is an elided window dim. This
      // means we set the iteration index to 0, so for the purpose of the
      // following calculations we can consider the output dimension size to
      // be 1.
      int64 output_dim_size =
          output_dim == -1 ? 1 : shape.dimensions(output_dim);
      // Clamp the gather index so that the gather region fits in the operand:
      //   input_index_clamped[i] = clamp(input_gather_index[i], 0,
      //                                  operand_shape.dimensions(i) -
      //                                  output_dim_size);
      input_index_clamped[i] =
          std::min(operand_shape.dimensions(i) - output_dim_size,
                   std::max(0LL, input_gather_index[i]));
    }
    for (int i = 0, e = input_index.size(); i < e; i++) {
      input_index[i] = input_index_clamped[i] + input_window_index[i];
      DCHECK_GE(input_index[i], 0);
      DCHECK_LT(input_index[i], operand_shape.dimensions(i));
    }
    TF_RETURN_IF_ERROR(
        result.CopyElementFrom(operand, input_index, output_index));
    return true;
  };

  auto gather_outer_loop_body =
      [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
    TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
                        output_batch_index_to_input_index(output_gather_index));
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
        shape, offset_indices_iteration_space,
        std::bind(gather_inner_loop_body, std::placeholders::_1,
                  input_gather_index, output_gather_index)));
    return true;
  };

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      shape, start_indices_iteration_space, gather_outer_loop_body));
  evaluated_[gather] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
  const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));

  TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank())
      << "broadcast dimensions is of size: " << broadcast->dimensions().size()
      << " and rank of operand_to_broadcast is: " << operand.shape().rank();
  // Checks that the operand's dimensions match the broadcast's dimensions
  // along the dimensions to be broadcast.
  for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
    auto operand_dim_size = operand.shape().dimensions(i);
    auto broadcast_dim_size =
        broadcast->shape().dimensions(broadcast->dimensions(i));
    TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
        "Operand dimension %d is broadcast to output dimension %d, but the "
        "sizes of these two dims do not match (%d vs %d): %s",
        i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
        broadcast->ToString());
  }

  TF_ASSIGN_OR_RETURN(
      evaluated_[broadcast],
      operand.Broadcast(broadcast->shape(), broadcast->dimensions()));

  return Status::OK();
}

Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
  evaluated_[after_all] = LiteralUtil::CreateToken();
  return Status::OK();
}

Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
  // AddDependency just forwards its zeroth operand.
  evaluated_[add_dependency] =
      GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
  return Status::OK();
}
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  const auto result_shape = get_tuple_element->shape();
  const int64 index = get_tuple_element->tuple_index();

  auto operand = get_tuple_element->operand(0);
  TF_ASSIGN_OR_RETURN(
      auto inferred_return_shape,
      ShapeInference::InferGetTupleElementShape(operand->shape(), index));
  TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
      << "return shape set to: " << ShapeUtil::HumanString(result_shape)
      << " but is inferred to be: "
      << ShapeUtil::HumanString(inferred_return_shape);

  const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);

  evaluated_[get_tuple_element] =
      Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
  return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
                                                /*dest_shape_index=*/{},
                                                /*src_shape_index=*/{index});
}

Status HloEvaluator::HandleCopy(HloInstruction* copy) {
  TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
  evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
  return Status::OK();
}

Status HloEvaluator::HandleCall(HloInstruction* call) {
  auto* computation = call->to_apply();
  auto operands = call->operands();

  std::vector<const Literal*> arg_literals;
  arg_literals.reserve(operands.size());
  for (auto operand : operands) {
    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
    arg_literals.push_back(&arg_literal);
  }

  HloEvaluator embedded_evaluator;
  embedded_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result,
                      embedded_evaluator.Evaluate(*computation, arg_literals));

  evaluated_[call] = std::move(result);
  return Status::OK();
}
Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
  HloModuleConfig config;
  // Attach the cloned computation to an empty HLO module so the existing ones
  // are not modified.
  HloModule empty_hlo_module("EmptyModuleForFusion", config);
  HloCloneContext context(&empty_hlo_module);
  auto cloned_fused_computation =
      fusion->fused_instructions_computation()->Clone(
          /*suffix=*/"clone_with_layout", &context);
  for (auto* instruction : cloned_fused_computation->instructions()) {
    if (!LayoutUtil::HasLayout(instruction->shape())) {
      LayoutUtil::SetToDefaultLayout(instruction->mutable_shape());
    }
  }
  auto readded_computation =
      empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation));

  auto operands = fusion->operands();
  std::vector<const Literal*> arg_literals;
  arg_literals.reserve(operands.size());
  for (auto operand : operands) {
    const Literal& arg_literal = GetEvaluatedLiteralFor(operand);
    arg_literals.push_back(&arg_literal);
  }

  HloEvaluator embedded_evaluator;
  embedded_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate(
                                          *readded_computation, arg_literals));

  evaluated_[fusion] = std::move(result);
  return Status::OK();
}
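// Branch selection: a PRED selector of true picks branch 0 and false picks
// branch 1; an S32 selector picks the branch with that index, with any
// out-of-range index clamped to the last branch (branch_count() - 1).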
Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
  const auto& branch_index_literal =
      GetEvaluatedLiteralFor(conditional->operand(0));
  int branch_index;
  if (conditional->operand(0)->shape().element_type() == PRED) {
    branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1;
  } else {
    branch_index = branch_index_literal.Get<int32>({});
    if (branch_index < 0 || branch_index >= conditional->branch_count()) {
      branch_index = conditional->branch_count() - 1;
    }
  }
  const auto& branch_computation_arg =
      GetEvaluatedLiteralFor(conditional->operand(1 + branch_index));

  HloEvaluator embedded_evaluator;
  embedded_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  TF_ASSIGN_OR_RETURN(Literal result,
                      embedded_evaluator.Evaluate(
                          *conditional->branch_computation(branch_index),
                          {&branch_computation_arg}));

  evaluated_[conditional] = std::move(result);
  return Status::OK();
}

Status HloEvaluator::HandleSelect(HloInstruction* select) {
  const auto& pred = GetEvaluatedLiteralFor(select->operand(0));
  const auto& on_true = GetEvaluatedLiteralFor(select->operand(1));
  const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));

  // If the predicate is scalar, no element-wise selection is needed.
  if (ShapeUtil::IsScalar(pred.shape())) {
    if (pred.Get<bool>({})) {
      evaluated_[select] = on_true.Clone();
    } else {
      evaluated_[select] = on_false.Clone();
    }
    return Status::OK();
  }

  return DefaultAction(select);
}

Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
  const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
  const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
  const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));

  if (pred.Get<bool>({})) {
    evaluated_[tuple_select] = on_true.Clone();
  } else {
    evaluated_[tuple_select] = on_false.Clone();
  }
  return Status::OK();
}

Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
  HloComputation* cond_comp = while_hlo->while_condition();
  HloComputation* body_comp = while_hlo->while_body();
  // Initialize the loop-carried value with the input to the While
  // instruction.
  auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
  bool keep_going = true;
  int64 iteration_count = 0;
  HloEvaluator cond_evaluator(max_loop_iterations_);
  cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_);
  HloEvaluator loop_body_evaluator(max_loop_iterations_);
  loop_body_evaluator.set_dynamic_dimension_inference(
      dynamic_dimension_inference_);
  while (keep_going) {
    if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
      return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
                             while_hlo->name(), max_loop_iterations_);
    }
    TF_ASSIGN_OR_RETURN(auto cond_val,
                        cond_evaluator.Evaluate(*cond_comp, {&lcv}));
    keep_going = cond_val.GetFirstElement<bool>();
    if (keep_going) {
      TF_ASSIGN_OR_RETURN(auto body_val,
                          loop_body_evaluator.Evaluate(*body_comp, {&lcv}));
      VLOG(3) << "Loop iteration result: " << body_val.ToString();
      lcv = std::move(body_val);
      cond_evaluator.ResetVisitStates();
      loop_body_evaluator.ResetVisitStates();
    }
  }
  evaluated_[while_hlo] = std::move(lcv);
  return Status::OK();
}
namespace {
template <typename NativeT>
Literal ExtractLiteralFromIndexPositions(const Literal& from,
                                         absl::Span<int64 const> indices,
                                         bool extract_as_scalar) {
  if (extract_as_scalar) {
    return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
  }
  // We use an absl::InlinedVector here because we need to convert it to an
  // absl::Span later, and this would not work with std::vector<bool>.
  absl::InlinedVector<NativeT, 10> values;
  for (int64 index : indices) {
    values.push_back(from.Get<NativeT>({index}));
  }
  return LiteralUtil::CreateR1<NativeT>(values);
}

StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
                                            absl::Span<int64 const> indices,
                                            bool extract_as_scalar = false) {
  if (extract_as_scalar) {
    CHECK_EQ(indices.size(), 1);
  }
  PrimitiveType type = from.shape().element_type();
  switch (type) {
    case PRED: {
      return ExtractLiteralFromIndexPositions<bool>(from, indices,
                                                    extract_as_scalar);
    }
    case U8: {
      return ExtractLiteralFromIndexPositions<uint8>(from, indices,
                                                     extract_as_scalar);
    }
    case S8: {
      return ExtractLiteralFromIndexPositions<int8>(from, indices,
                                                    extract_as_scalar);
    }
    case BF16: {
      return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
                                                        extract_as_scalar);
    }
    case F16: {
      return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
                                                           extract_as_scalar);
    }
    case U16: {
      return ExtractLiteralFromIndexPositions<uint16>(from, indices,
                                                      extract_as_scalar);
    }
    case S16: {
      return ExtractLiteralFromIndexPositions<int16>(from, indices,
                                                     extract_as_scalar);
    }
    case F32: {
      return ExtractLiteralFromIndexPositions<float>(from, indices,
                                                     extract_as_scalar);
    }
    case U32: {
      return ExtractLiteralFromIndexPositions<uint32>(from, indices,
                                                      extract_as_scalar);
    }
    case S32: {
      return ExtractLiteralFromIndexPositions<int32>(from, indices,
                                                     extract_as_scalar);
    }
    case F64: {
      return ExtractLiteralFromIndexPositions<double>(from, indices,
                                                      extract_as_scalar);
    }
    case U64: {
      return ExtractLiteralFromIndexPositions<uint64>(from, indices,
                                                      extract_as_scalar);
    }
    case S64: {
      return ExtractLiteralFromIndexPositions<int64>(from, indices,
                                                     extract_as_scalar);
    }
    default:
      return InvalidArgument("Unsupported type for Sort: %s",
                             PrimitiveType_Name(type));
  }
}
}  // namespace

Status HloEvaluator::HandleSort(HloInstruction* sort) {
  TF_RET_CHECK(sort->operand_count() >= 1)
      << "Expected at least 1 operand for sort";
  for (int64 i = 1; i < sort->operand_count(); ++i) {
    TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
                                           sort->operand(i)->shape()))
        << "All Sort operands must have the same dimensions";
  }

  if (VLOG_IS_ON(3)) {
    for (int64 i = 0; i < sort->operand_count(); ++i) {
      VLOG(3) << "HandleSort operand " << i << " literal: "
              << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
    }
  }
  Shape key_shape = sort->operand(0)->shape();
  auto rank = key_shape.rank();
  std::vector<Literal> result_literals;
  result_literals.reserve(sort->operand_count());
  for (int64 i = 0; i < sort->operand_count(); ++i) {
    result_literals.emplace_back(sort->operand(i)->shape());
  }
  std::vector<int64> zero_base(rank, 0);
  std::vector<int64> increment(rank, 1);
  int64 sort_dim = sort->dimensions(0);
  int64 sort_dim_elements = key_shape.dimensions(sort_dim);
  increment[sort_dim] = sort_dim_elements;
  HloEvaluator embedded_evaluator(max_loop_iterations_);
  // Iterate through each dimension except 'sort_dim'.
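  // For example (illustrative): for a single operand row [3, 1, 2] and a
  // less-than comparator, sorting indices_to_sort = {0, 1, 2} below yields
  // {1, 2, 0}; extracting the elements at those positions produces the
  // sorted row [1, 2, 3]. Any additional values operands are permuted by the
  // same indices.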
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
      key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
      [&](absl::Span<const int64> indices) -> StatusOr<bool> {
        // Extract a slice from each operand literal that corresponds to
        // exactly the row in dimension 'sort_dim'.
        std::vector<int64> limit_indices(indices.begin(), indices.end());
        absl::c_for_each(limit_indices, [](int64& index) { ++index; });
        limit_indices[sort_dim] = sort_dim_elements;
        std::vector<Literal> literals_to_sort;
        literals_to_sort.reserve(sort->operand_count());
        for (int64 i = 0; i < sort->operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(auto literal_to_sort,
                              GetEvaluatedLiteralFor(sort->operand(i))
                                  .Slice(indices, limit_indices)
                                  .Reshape({sort_dim_elements}));
          literals_to_sort.push_back(std::move(literal_to_sort));
        }
        std::vector<int64> indices_to_sort(sort_dim_elements);
        std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
        Status compare_status = Status::OK();
        auto comparator = [sort, &compare_status, &embedded_evaluator,
                           &literals_to_sort](int64 a, int64 b) {
          std::vector<Literal> literals;
          literals.reserve(2 * sort->operand_count());
          for (int64 i = 0; i < sort->operand_count(); ++i) {
            auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a},
                                                 /*extract_as_scalar=*/true);
            if (!lhs.ok()) {
              compare_status = lhs.status();
              return false;
            }
            literals.push_back(std::move(lhs.ValueOrDie()));
            auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b},
                                                 /*extract_as_scalar=*/true);
            if (!rhs.ok()) {
              compare_status = rhs.status();
              return false;
            }
            literals.push_back(std::move(rhs.ValueOrDie()));
          }
          std::vector<const Literal*> literal_ptrs;
          absl::c_transform(literals, std::back_inserter(literal_ptrs),
                            [](const Literal& literal) { return &literal; });

          auto computed_result =
              embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs);
          // Clear visit states so that we can use the evaluator again on
          // the same computation.
          embedded_evaluator.ResetVisitStates();
          if (!computed_result.ok()) {
            compare_status = computed_result.status();
            return false;
          }
          return computed_result.ValueOrDie().Get<bool>({});
        };
        if (Cast<HloSortInstruction>(sort)->is_stable()) {
          std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(),
                           comparator);
        } else {
          std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator);
        }
        if (!compare_status.ok()) {
          return compare_status;
        }
        std::vector<int64> slice_dimensions(rank, 1);
        slice_dimensions[sort_dim] = sort_dim_elements;
        std::vector<int64> start_indices(rank, 0);
        for (int64 i = 0; i < sort->operand_count(); ++i) {
          TF_ASSIGN_OR_RETURN(
              Literal sorted_literal,
              ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
          TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
                              sorted_literal.Reshape(slice_dimensions));
          TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
              sorted_literal_reshaped, start_indices, indices,
              slice_dimensions));
        }
        return true;
      }));

  if (sort->operand_count() == 1) {
    evaluated_[sort] = std::move(result_literals[0]);
  } else {
    std::vector<const Literal*> literal_ptrs;
    absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
                      [](const Literal& literal) { return &literal; });

    Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs);
    VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();

    evaluated_[sort] = std::move(result_tuple);
  }
  return Status::OK();
}

Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
  if (!reduce->shape().IsTuple()) {
    return DefaultAction(reduce);
  } else {
    auto first_element_type = reduce->shape().tuple_shapes(0).element_type();
    for (const auto& tuple_shape : reduce->shape().tuple_shapes()) {
      if (tuple_shape.element_type() != first_element_type) {
        return Unimplemented(
            "Reduce with several outputs that have mixed element types is "
            "unsupported");
      }
    }
    return reduce->Visit(typed_visitors_[first_element_type].get());
  }
}

Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) {
  if (!custom_call_handler_) {
    // No handler is registered; this means custom-calls are not allowed.
    return DefaultAction(custom_call);
  }

  // Evaluate input operands so the handler has access to the operand data.
  std::vector<const Literal*> operands;
  operands.reserve(custom_call->operand_count());
  for (const HloInstruction* operand : custom_call->operands()) {
    operands.push_back(&GetEvaluatedLiteralFor(operand));
  }

  // Synchronously issue the handler to populate the instruction's output
  // literal.
  TF_ASSIGN_OR_RETURN(
      auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands)));

  evaluated_[custom_call] = std::move(output);
  return Status::OK();
}

Status HloEvaluator::Preprocess(HloInstruction* hlo) {
  VLOG(2) << "About to visit HLO: " << hlo->ToString();
  return ShapeUtil::ValidateShape(hlo->shape());
}

Status HloEvaluator::Postprocess(HloInstruction* hlo) {
  VLOG(2) << "Finished visiting " << hlo->ToString()
          << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
  // For convenience, the literal may have been produced in a different
  // layout; relayout it as indicated by the HLO instruction.
  if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
                                        hlo->shape())) {
    evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
  }
  return Status::OK();
}
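// Note on the argument order in MatmulArray2DImpl below: the lhs/rhs
// pointers and the m/n extents are passed to the Eigen matmul runtime
// swapped. Our reading (an assumption; the runtime exposes only the raw
// signature) is that the runtime treats matrices as column-major while
// Array2D is row-major, and computing rhs^T * lhs^T in column-major storage
// produces the same bytes as lhs * rhs in row-major storage.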
namespace {
template <typename T>
std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
    const Array2D<T>& lhs, const Array2D<T>& rhs,
    const std::function<void(
        const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n,
        int64 k, int32 transpose_lhs, int32 transpose_rhs)>& impl_fn) {
  CHECK_EQ(lhs.width(), rhs.height());
  int m = lhs.height();
  int n = rhs.width();
  int k = lhs.width();
  auto result = absl::make_unique<Array2D<T>>(m, n);
  // Because Eigen is a header-oriented library, make sure that the Eigen code
  // is the same as the code used by the CPU backend (otherwise the linker
  // will randomly pick *some* definition).
  impl_fn(
      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n,
      m, k,
      /*transpose_lhs=*/0,
      /*transpose_rhs=*/0);
  return result;
}
}  // namespace

std::unique_ptr<Array2D<Eigen::half>> HloEvaluator::MatmulArray2D(
    const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
  return MatmulArray2DImpl<Eigen::half>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16);
}

std::unique_ptr<Array2D<float>> HloEvaluator::MatmulArray2D(
    const Array2D<float>& lhs, const Array2D<float>& rhs) {
  return MatmulArray2DImpl<float>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
}

std::unique_ptr<Array2D<double>> HloEvaluator::MatmulArray2D(
    const Array2D<double>& lhs, const Array2D<double>& rhs) {
  return MatmulArray2DImpl<double>(
      lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64);
}

}  // namespace xla