/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

template <typename T>
struct is_complex_t : public std::false_type {};

template <>
struct is_complex_t<complex64> : public std::true_type {};

template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
                                           const Literal& lhs_literal,
                                           const Literal& rhs_literal) {
  std::function<bool(OperandT, OperandT)> compare_op;
  switch (opcode) {
    case HloOpcode::kEq:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case HloOpcode::kNe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    case HloOpcode::kGe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el >= rhs_el;
      };
      break;
    case HloOpcode::kGt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el > rhs_el;
      };
      break;
    case HloOpcode::kLe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el <= rhs_el;
      };
      break;
    case HloOpcode::kLt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el < rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
                 << HloOpcodeString(opcode);
  }

  auto result = Literal::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(result->Populate<bool>(
      [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return compare_op(lhs_literal.Get<OperandT>(multi_index),
                          rhs_literal.Get<OperandT>(multi_index));
      }));

  return std::move(result);
}

template <>
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
    const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
    const Literal& rhs_literal) {
  std::function<bool(complex64, complex64)> compare_op;
  switch (opcode) {
    case HloOpcode::kEq:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case HloOpcode::kNe:
      compare_op = [](complex64 lhs_el, complex64 rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
                 << HloOpcodeString(opcode);
  }

  auto result = Literal::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(result->Populate<bool>(
      [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return compare_op(lhs_literal.Get<complex64>(multi_index),
                          rhs_literal.Get<complex64>(multi_index));
      }));

  return std::move(result);
}

template <typename ReturnT, typename NativeT>
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
    HloInstruction* instruction,
    const std::function<ReturnT(NativeT)>& unary_op,
    const Literal& operand_literal) {
  const auto shape = instruction->shape();
  const auto* operand = instruction->operand(0);

  // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
  // removed.
  if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
    return Unimplemented(
        "Implicit broadcasting is currently unsupported in HLO evaluator. "
        "Shape mismatch: %s vs %s",
        ShapeUtil::HumanString(shape).c_str(),
        ShapeUtil::HumanString(operand->shape()).c_str());
  }

  auto result = Literal::CreateFromShape(shape);

  TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
      [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return unary_op(operand_literal.Get<NativeT>(multi_index));
      }));
  return std::move(result);
}

// For one particular placement of a window in a base shape (the placement is
// represented as `window_count_index`), iterates inside the window.
// Translates the window index into a base index. If the base index is within
// bounds, calls `f` with the base index.
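// For example, in a dimension with stride 2 and padding_low 1, window
// placement w and in-window offset k map to base index w * 2 + k - 1.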
void IterateThroughWindow(
    const Shape& window_shape, const Window& window, const Shape& base_shape,
    const tensorflow::gtl::ArraySlice<int64>& window_count_index,
    const std::function<void(const std::vector<int64>&)>& f) {
  const int64 rank = ShapeUtil::Rank(base_shape);
  DimensionVector window_index(rank);
  std::fill(window_index.begin(), window_index.end(), 0);
  do {
    std::vector<int64> base_index(rank);
    bool out_of_bound = false;
    for (int64 i = 0; i < rank; ++i) {
      base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
                      window_index[i] - window.dimensions(i).padding_low();
      if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
        out_of_bound = true;
        break;
      }
    }
    if (!out_of_bound) {
      f(base_index);
    }
  } while (IndexUtil::BumpIndices(window_shape, &window_index));
}

}  // namespace

template <typename ReturnT, typename ElementwiseT>
class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit TypedVisitor(HloEvaluator* p) : parent_(p) {}

  // The following higher-order functions convert a function with ElementwiseT
  // to a function with ReturnT.
  std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
      const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
    return [&unary_op](ReturnT arg) {
      return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
    };
  }
  std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
          binary_op) {
    return [&binary_op](ReturnT arg1, ReturnT arg2) {
      return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
                                            static_cast<ElementwiseT>(arg2)));
    };
  }
  std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
                                       ElementwiseT)>& ternary_op) {
    return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
      return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
                                             static_cast<ElementwiseT>(arg2),
                                             static_cast<ElementwiseT>(arg3)));
    };
  }

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    return Unimplemented("unhandled HLO op for HloEvaluator: %s.",
                         HloOpcodeString(hlo_instruction->opcode()).c_str());
  }

  // TODO(b/35950897): many of the STL functions used in the handlers are not
  // overloaded for every XLA primitive type.

  template <typename NativeT,
            typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
                nullptr>
  Status HandleAbs(HloInstruction* abs) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
                          return elem_operand;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_signed<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleAbs(HloInstruction* abs) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
                        ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) {
                          return std::abs(elem_operand);
                        }));
    return Status::OK();
  }

  Status HandleAbs(HloInstruction* abs) override {
    return HandleAbs<ElementwiseT>(abs);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRound(HloInstruction* round) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[round],
        ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) {
          return std::round(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRound(HloInstruction* round) {
    return InvalidArgument("Unsupported type for Round");
  }

  Status HandleRound(HloInstruction* round) override {
    return HandleRound<ReturnT>(round);
  }

  Status HandleBroadcast(HloInstruction* broadcast) override {
    parent_->evaluated_[broadcast] =
        Literal::CreateFromShape(broadcast->shape());
    auto output = parent_->evaluated_[broadcast].get();
    const Literal& operand_to_broadcast =
        parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
    std::vector<int64> broadcast_indices(
        ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);

    TF_RET_CHECK(broadcast->dimensions().size() ==
                 ShapeUtil::Rank(operand_to_broadcast.shape()))
        << "broadcast dimensions is of size: " << broadcast->dimensions().size()
        << " and rank of operand_to_broadcast is: "
        << ShapeUtil::Rank(operand_to_broadcast.shape());
    // Checks that operand's dimensions are the same as the broadcast's
    // dimensions along the dimensions to be broadcasted.
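    // For example, broadcasting an operand of shape [3] into a result of
    // shape [2, 3] with dimensions={1} requires dimension 1 of the result
    // (size 3) to match dimension 0 of the operand (size 3).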
    for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
      TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
                   operand_to_broadcast.shape().dimensions(i));
    }

    return output->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
            broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
          }
          return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
        });
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCeil(HloInstruction* ceil) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
                        ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) {
                          return std::ceil(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleCeil(HloInstruction* ceil) {
    return InvalidArgument("Unsupported type for Ceil");
  }

  Status HandleCeil(HloInstruction* ceil) override {
    return HandleCeil<ReturnT>(ceil);
  }

  Status HandleConvert(HloInstruction* convert) override {
    const HloInstruction* operand = convert->operand(0);
    TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
                        parent_->GetEvaluatedLiteralFor(operand).Convert(
                            convert->shape().element_type()));

    if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
      parent_->evaluated_[convert] = std::move(result);
    } else {
      parent_->evaluated_[convert] =
          result->Relayout(convert->shape().layout());
    }
    return Status::OK();
  }

  Status HandleExp(HloInstruction* exp) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
                        ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) {
                          return std::exp(elem_operand);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleFloor(HloInstruction* floor) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[floor],
        ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) {
          return std::floor(elem_operand);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleFloor(HloInstruction* floor) {
    return InvalidArgument("Unsupported type for Floor");
  }

  Status HandleFloor(HloInstruction* floor) override {
    return HandleFloor<ReturnT>(floor);
  }

  Status HandleLog(HloInstruction* log) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
                        ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
                          return std::log(elem_operand);
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return ~elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return !elem_operand;
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleNot(HloInstruction* not_) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
                        ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) {
                          return !elem_operand;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleNot(HloInstruction* not_) {
    return InvalidArgument("Unsupported type for Not");
  }

  Status HandleNot(HloInstruction* not_) override {
    return HandleNot<ElementwiseT>(not_);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_signed<NativeT>::value &&
                !std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
    using type = typename std::make_unsigned<NativeT>::type;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[negate],
        ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) {
          return NativeT(-type(elem_operand));
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                !std::is_signed<NativeT>::value ||
                std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleNegate(HloInstruction* negate) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[negate],
        ElementWiseUnaryOp(
            negate, [](ElementwiseT elem_operand) { return -elem_operand; }));
    return Status::OK();
  }

  Status HandleNegate(HloInstruction* negate) override {
    return HandleNegate<ReturnT>(negate);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleSign(HloInstruction* sign) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
                        ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
                          return (ElementwiseT(0) < elem_operand) -
                                 (elem_operand < ElementwiseT(0));
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleSign(HloInstruction* sign) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
                        ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
                          auto abs_val = std::abs(elem_operand);
                          return 0 == abs_val ?
                                     ElementwiseT(0)
                                     : elem_operand / abs_val;
                        }));
    return Status::OK();
  }

  Status HandleSign(HloInstruction* sign) override {
    return HandleSign<ReturnT>(sign);
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAtan2(HloInstruction* atan2) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
                        ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
                                                      ElementwiseT rhs_elem) {
                          return std::atan2(lhs_elem, rhs_elem);
                        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<!std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAtan2(HloInstruction* atan2) {
    return InvalidArgument("Unsupported type for Atan2");
  }

  Status HandleAtan2(HloInstruction* atan2) override {
    return HandleAtan2<ElementwiseT>(atan2);
  }

  Status HandleTanh(HloInstruction* tanh) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
                        ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
                          return std::tanh(elem_operand);
                        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_signed<NativeT>::value &&
                !std::is_floating_point<NativeT>::value>::type* = nullptr>
  Status HandleMultiply(HloInstruction* multiply) {
    using type = typename std::make_unsigned<NativeT>::type;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[multiply],
        ElementWiseBinaryOp(multiply,
                            [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
                              return NativeT(type(lhs_elem) * type(rhs_elem));
                            }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_unsigned<NativeT>::value ||
                              std::is_floating_point<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMultiply(HloInstruction* multiply) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[multiply],
        ElementWiseBinaryOp(multiply,
                            [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
                              return lhs_elem * rhs_elem;
                            }));
    return Status::OK();
  }

  Status HandleMultiply(HloInstruction* multiply) override {
    return HandleMultiply<ElementwiseT>(multiply);
  }

  Status HandleSubtract(HloInstruction* subtract) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[subtract],
        ElementWiseBinaryOp(subtract,
                            [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) {
                              return lhs_elem - rhs_elem;
                            }));
    return Status::OK();
  }

  Status HandleAdd(HloInstruction* add) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[add],
                        ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem,
                                                    ElementwiseT rhs_elem) {
                          return lhs_elem + rhs_elem;
                        }));
    return Status::OK();
  }

  Status HandleDivide(HloInstruction* divide) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
                        ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
                                                       ElementwiseT rhs_elem) {
                          return lhs_elem / rhs_elem;
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMaximum(HloInstruction* maximum) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[maximum],
        ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) {
          return std::fmax(lhs, rhs);
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMaximum(HloInstruction* maximum) {
    return InvalidArgument("Unsupported type for Maximum");
  }

  Status HandleMaximum(HloInstruction* maximum) override {
    return HandleMaximum<ElementwiseT>(maximum);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMinimum(HloInstruction* minimum) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum],
                        ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el,
                                                        ElementwiseT rhs_el) {
                          return std::fmin(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleMinimum(HloInstruction* minimum) {
    return InvalidArgument("Unsupported type for Minimum");
  }

  Status HandleMinimum(HloInstruction* minimum) override {
    return HandleMinimum<ElementwiseT>(minimum);
  }

  Status HandlePower(HloInstruction* power) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[power],
                        ElementWiseBinaryOp(power, [](ElementwiseT lhs_el,
                                                      ElementwiseT rhs_el) {
                          return std::pow(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
                        ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
                                                          ElementwiseT rhs_el) {
                          return std::fmod(lhs_el, rhs_el);
                        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleRemainder(HloInstruction* remainder) {
    return InvalidArgument("Unsupported type for Remainder");
  }

  Status HandleRemainder(HloInstruction* remainder) override {
    return HandleRemainder<ElementwiseT>(remainder);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleAnd(HloInstruction* and_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[and_],
        ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el & rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                  NativeT>::value>::type* = nullptr>
  Status HandleAnd(HloInstruction* and_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[and_],
        ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el && rhs_el;
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleAnd(HloInstruction* and_) {
    return InvalidArgument("Unsupported type for And");
  }

  Status HandleAnd(HloInstruction* and_) override {
    return HandleAnd<ElementwiseT>(and_);
  }

  template <typename NativeT,
            typename std::enable_if<std::is_integral<NativeT>::value>::type* =
                nullptr>
  Status HandleOr(HloInstruction* or_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[or_],
        ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el | rhs_el;
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<std::is_floating_point<
                NativeT>::value>::type* = nullptr>
  Status HandleOr(HloInstruction* or_) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[or_],
        ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
          return lhs_el || rhs_el;
        }));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleOr(HloInstruction* or_) {
    return InvalidArgument("Unsupported type for Or");
  }

  Status HandleOr(HloInstruction* or_) override {
    return HandleOr<ElementwiseT>(or_);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftLeft(HloInstruction* shl) {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shl],
        ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
          return lhs_elem << rhs_elem;
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftLeft(HloInstruction*) {
    return InvalidArgument("Unsupported type for ShiftLeft");
  }

  Status HandleShiftLeft(HloInstruction* shl) override {
    return HandleShiftLeft<ElementwiseT>(shl);
  }
  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftRightArithmetic(HloInstruction* shr) {
    typedef typename std::make_signed<NativeT>::type SignedT;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shr],
        ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
          return static_cast<NativeT>(static_cast<SignedT>(lhs_elem) >>
                                      rhs_elem);
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftRightArithmetic(HloInstruction*) {
    return InvalidArgument("Unsupported type for ShiftRightArithmetic");
  }

  Status HandleShiftRightArithmetic(HloInstruction* shra) override {
    return HandleShiftRightArithmetic<ElementwiseT>(shra);
  }

  template <typename NativeT,
            typename std::enable_if<
                std::is_integral<NativeT>::value &&
                !std::is_same<NativeT, bool>::value>::type* = nullptr>
  Status HandleShiftRightLogical(HloInstruction* shr) {
    typedef typename std::make_unsigned<NativeT>::type UnsignedT;
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[shr],
        ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
          // If the shift amount is greater than or equal to the number of
          // bits, return 0.
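          // e.g. for a 32-bit operand, a shift amount of 32 or more yields 0.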
          if (rhs_elem >= sizeof(UnsignedT) * CHAR_BIT) {
            return static_cast<NativeT>(0);
          }
          return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
                                      rhs_elem);
        }));
    return Status::OK();
  }

  template <typename NativeT,
            typename std::enable_if<!std::is_integral<NativeT>::value ||
                                    std::is_same<NativeT, bool>::value>::type* =
                nullptr>
  Status HandleShiftRightLogical(HloInstruction*) {
    return InvalidArgument("Unsupported type for ShiftRightLogical");
  }

  Status HandleShiftRightLogical(HloInstruction* shrl) override {
    return HandleShiftRightLogical<ElementwiseT>(shrl);
  }

  template <
      typename NativeT,
      typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleClamp(HloInstruction* clamp) {
    std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)>
        clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) {
          return std::fmax(low, std::fmin(value, high));
        };
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[clamp],
        ElementwiseTernaryOp(clamp,
                             std::move(ConvertTernaryFunction(clamp_op))));
    return Status::OK();
  }

  template <
      typename NativeT,
      typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleClamp(HloInstruction*) {
    return InvalidArgument("Unsupported type for Clamp");
  }

  Status HandleClamp(HloInstruction* clamp) override {
    return HandleClamp<ElementwiseT>(clamp);
  }

  Status HandleSelect(HloInstruction* select) override {
    CHECK(!ShapeUtil::IsTuple(select->shape()));
    std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
        [](bool pred, ReturnT on_true, ReturnT on_false) {
          if (pred) {
            return on_true;
          }
          return on_false;
        };
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
                        ElementwiseTernaryOp(select, std::move(select_op)));
    return Status::OK();
  }

  Status HandleReverse(HloInstruction* reverse) override {
    const auto result_shape = reverse->shape();
    const auto reverse_dimensions = reverse->dimensions();

    auto operand = reverse->operand(0);
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferReverseShape(operand->shape(),
                                                          reverse_dimensions));

    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    auto result = Literal::CreateFromShape(result_shape);

    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> out_index) {
          std::vector<int64> from_index(out_index.begin(), out_index.end());
          for (const int64 dim : reverse_dimensions) {
            from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
          }
          return operand_literal.Get<ReturnT>(from_index);
        }));

    parent_->evaluated_[reverse] = std::move(result);
    return Status::OK();
  }

  Status HandleConvolution(HloInstruction* conv) override {
    auto lhs = conv->operand(0);
    auto rhs = conv->operand(1);
    const auto& window = conv->window();
    const Shape& result_shape = conv->shape();
    const Shape& lhs_shape = lhs->shape();
    const Shape& rhs_shape = rhs->shape();

    TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape));
    TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape));
    CHECK(ShapeUtil::IsArray(lhs_shape));
    CHECK(ShapeUtil::IsArray(rhs_shape));
    CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape));
    CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape));

    const auto& dnums = conv->convolution_dimension_numbers();
    const int64 num_spatial_dims = dnums.output_spatial_dimensions_size();
    CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size());
    CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
    CHECK_GE(num_spatial_dims, 0);
    CHECK_EQ(window.dimensions_size(), num_spatial_dims);

    const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
    const auto rhs_rank = ShapeUtil::Rank(rhs_shape);

    CHECK_EQ(num_spatial_dims + 2, lhs_rank);
    CHECK_EQ(num_spatial_dims + 2, rhs_rank);

    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
                                                           window, dnums));
    CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    // Dimension number applicable for input (lhs).
    const int64 input_batch_dim = dnums.input_batch_dimension();
    const int64 input_z_dim = dnums.input_feature_dimension();
    // Dimension number applicable for kernel (rhs).
    const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
    const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
    // Dimension number applicable for output.
    const int64 output_batch_dim = dnums.output_batch_dimension();
    const int64 output_z_dim = dnums.output_feature_dimension();

    const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);

    std::vector<int64> window_dimension_sizes;
    for (auto i : dnums.kernel_spatial_dimensions()) {
      window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
    }

    const Shape& window_shape =
        ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);

    DimensionVector lhs_index(lhs_rank);
    DimensionVector rhs_index(rhs_rank);
    DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());

    auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
      ElementwiseT result_val = static_cast<ElementwiseT>(0);

      std::fill(lhs_index.begin(), lhs_index.end(), 0);
      std::fill(rhs_index.begin(), rhs_index.end(), 0);
      std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);

      lhs_index[input_batch_dim] = out_index[output_batch_dim];
      rhs_index[kernel_output_z_dim] = out_index[output_z_dim];

      // Convolve input feature with kernel.
      do {
        for (int64 iz = 0; iz < z_size; ++iz) {
          lhs_index[input_z_dim] = iz;
          rhs_index[kernel_input_z_dim] = iz;

          // Find corresponding spatial dimension index for input (lhs).
          for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
            // Spatial dimension number for input (lhs) and output.
            const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki);
            const int64 output_spatial_dim =
                dnums.output_spatial_dimensions(ki);

            // Calculate lhs (input) index without taking base dilation into
            // account.
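            // For example, with stride 2, padding_low 1, and window_dilation
            // 2, output spatial index 3 and kernel offset 1 give undilated
            // lhs index 3 * 2 - 1 + 1 * 2 = 7.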
            const auto& window_dim = window.dimensions(ki);
            const int64 undilated_index =
                out_index[output_spatial_dim] * window_dim.stride() -
                window_dim.padding_low() +
                rhs_spatial_index[ki] * window_dim.window_dilation();
            // Skip if the lhs (input) index is to be dilated. As an
            // optimization, skip this mod if there's no dilation.
            if (window_dim.base_dilation() > 1 &&
                undilated_index % window_dim.base_dilation() != 0) {
              goto cnt;
            }

            // Calculate the actual lhs (input) index after dilation. As an
            // optimization, skip this integer divide if there's no dilation.
            if (window_dim.base_dilation() > 1) {
              lhs_index[input_spatial_dim] =
                  undilated_index / window_dim.base_dilation();
            } else {
              lhs_index[input_spatial_dim] = undilated_index;
            }

            // Skip if the input index is out of bounds.
            if (!(lhs_index[input_spatial_dim] >= 0 &&
                  lhs_index[input_spatial_dim] <
                      lhs_shape.dimensions(input_spatial_dim))) {
              goto cnt;
            }

            rhs_index[dnums.kernel_spatial_dimensions(ki)] =
                window_dim.window_reversal()
                    ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
                    : rhs_spatial_index[ki];
          }

          result_val +=
              static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
              static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
        }
      cnt : {}
      } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));

      return static_cast<ReturnT>(result_val);
    };

    auto result = Literal::CreateFromShape(result_shape);
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));

    parent_->evaluated_[conv] = std::move(result);
    return Status::OK();
  }

  Status HandleDot(HloInstruction* dot) override {
    auto lhs = dot->operand(0);
    auto rhs = dot->operand(1);
    CHECK(ShapeUtil::IsArray(dot->shape()));
    CHECK(ShapeUtil::IsArray(lhs->shape()));
    CHECK(ShapeUtil::IsArray(rhs->shape()));

    const auto& dnums = dot->dot_dimension_numbers();

    const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
    const auto rhs_rank = ShapeUtil::Rank(rhs->shape());

    CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
    CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));

    // There must be exactly one contracting dimension for both lhs and rhs.
    CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1);
    CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1);
    const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
    const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
    // Contracted dimension sizes must be the same.
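    // For example, a dot of a [2, 3] lhs with a [3, 4] rhs contracts lhs
    // dimension 1 against rhs dimension 0; both must have size 3 here.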
    CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
             rhs->shape().dimensions(rhs_contracting_dimension))
        << "lhs contracted dimension: "
        << lhs->shape().dimensions(lhs_contracting_dimension)
        << " rhs contracted dimension: "
        << rhs->shape().dimensions(rhs_contracting_dimension);
    const int64 contracted_dimension_size =
        lhs->shape().dimensions(lhs_contracting_dimension);

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    auto result = Literal::CreateFromShape(dot->shape());

    CHECK_EQ(dnums.lhs_batch_dimensions_size(),
             dnums.rhs_batch_dimensions_size());

    std::vector<int64> lhs_non_contracting_dims;
    for (int64 i = 0; i < lhs_rank; i++) {
      if (i != lhs_contracting_dimension) {
        lhs_non_contracting_dims.push_back(i);
      }
    }

    std::vector<int64> rhs_non_batch_non_contracting_dims;
    tensorflow::gtl::FlatSet<int64> batch_dims_set(
        dnums.rhs_batch_dimensions().begin(),
        dnums.rhs_batch_dimensions().end());
    for (int64 i = 0; i < rhs_rank; i++) {
      if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
        rhs_non_batch_non_contracting_dims.push_back(i);
      }
    }

    const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
    const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();

    DimensionVector lhs_index(lhs_rank);
    DimensionVector rhs_index(rhs_rank);
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> result_index) {
          ElementwiseT result_val = static_cast<ElementwiseT>(0);

          // Find the corresponding non-contracting indices for lhs and rhs.
          //
          // For `result_index`, its batch dimension, if it exists, will be at
          // the same dimension as the batch dimension of lhs and rhs. More
          // specifically:
          // - For lhs, the non-contracting dimensions, including the batch
          //   dimension, have the same index as `result_index`.
          // - For rhs, the batch dimension is set separately from the other
          //   non-contracting dimensions, since these other non-contracting
          //   dimensions in rhs follow the non-contracting dimensions of lhs
          //   in the resulting index.
          //
          // As an example, for a resulting index:
          //   result_index [result_batch, result_x, result_y]
          // the corresponding lhs and rhs indices are:
          //   lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
          //   rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
          // `result_x` is only affected by the lhs_non_contracting_dim and
          // likewise `result_y` only depends on rhs_non_contracting_dim.
          //
          // So we can look up the lhs and rhs indices by:
          //
          // lhs:
          //   batch index is the same as `result_batch`.
          //   non-contracting dimension is the same as
          //   result_index[lhs_non_contracting_dim]
          // rhs:
          //   batch index: the same as `result_batch`.
          //   non-contracting dimension index: *not* the same as
          //   result_index[rhs_non_contracting_dim], since the
          //   non-contracting dimensions of lhs are included in the
          //   result_index first.
          //   Instead, the non_contracting_dim of rhs must be calculated as
          //   follows:
          //     lhs_non_contracting_dimensions_size +
          //     (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
          //
          //   Note that (rhs_non_batch_non_contracting_dim - batch_dim_size)
          //   is the index offset into result_index that only depends on the
          //   non-batch and non-contracting dimensions of rhs. The -1 at the
          //   end translates a size into an index.
          for (auto i : lhs_non_contracting_dims) {
            lhs_index[i] = result_index[i];
          }
          for (auto i : dnums.rhs_batch_dimensions()) {
            rhs_index[i] = result_index[i];
          }
          for (auto i : rhs_non_batch_non_contracting_dims) {
            const int64 rhs_non_batch_non_contracting_dim =
                lhs_non_contracting_size + (i - batch_dim_size) - 1;
            rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
          }

          // Accumulate the resulting product along the contracted dimension.
          for (int64 i = 0; i < contracted_dimension_size; ++i) {
            lhs_index[lhs_contracting_dimension] = i;
            rhs_index[rhs_contracting_dimension] = i;

            result_val +=
                static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
                static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
          }

          return static_cast<ReturnT>(result_val);
        }));

    parent_->evaluated_[dot] = std::move(result);
    return Status::OK();
  }

  Status HandlePad(HloInstruction* pad) override {
    CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
    // The padding value must be a scalar.
    CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
    CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
             pad->padding_config().dimensions_size());

    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferPadShape(
                            /*operand_shape=*/pad->operand(0)->shape(),
                            /*padding_value_shape=*/pad->operand(1)->shape(),
                            /*padding_config=*/pad->padding_config()));
    CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    // Create a new literal of the padded shape, filled with the padding value.
    ReturnT scalar =
        parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
    auto result = Literal::CreateFromShape(pad->shape());
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
          return scalar;
        }));

    const Literal& evaluated_operand =
        parent_->GetEvaluatedLiteralFor(pad->operand(0));

    std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
                                   0);
    std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);

    // Loop through each element of the operand and assign it to the
    // corresponding index of the resulting padded literal.
    const PaddingConfig& pad_config = pad->padding_config();

    auto func = [&](const std::vector<int64>& input_index) {
      for (auto i = 0; i < input_index.size(); ++i) {
        // Interior padding occurs logically before edge padding, so in the
        // case of negative edge padding, elements are removed from the
        // interior-padded operand.
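        // For example, with edge_padding_low 2 and interior_padding 1 in this
        // dimension, operand index 3 maps to target index 2 + 3 * (1 + 1) = 8.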
        target_index[i] =
            pad_config.dimensions(i).edge_padding_low() +
            input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);

        // Account for negative low and high padding: skip the assignment if
        // any target index is out of range.
        if (!(target_index[i] >= 0 &&
              target_index[i] < pad->shape().dimensions(i))) {
          return true;
        }
      }
      result->Set<ReturnT>(target_index,
                           evaluated_operand.Get<ReturnT>(input_index));
      return true;
    };

    std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(),
                                 0);
    std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1);

    ShapeUtil::ForEachIndex(
        evaluated_operand.shape(), zero_base,
        AsInt64Slice(evaluated_operand.shape().dimensions()), step, func);

    parent_->evaluated_[pad] = std::move(result);
    return Status::OK();
  }

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
    auto operand = dynamic_slice->operand(0);
    auto start_indices = dynamic_slice->operand(1);
    auto result_shape = dynamic_slice->shape();
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferDynamicSliceShape(
                            operand->shape(), start_indices->shape(),
                            dynamic_slice->dynamic_slice_sizes()));
    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);
    TF_RET_CHECK(
        primitive_util::IsIntegralType(start_indices->shape().element_type()));

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    const Literal& start_indices_literal =
        parent_->GetEvaluatedLiteralFor(start_indices);

    switch (start_indices->shape().element_type()) {
      case S32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<int32>(operand_literal, start_indices_literal,
                                result_shape));
      } break;
      case S64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<int64>(operand_literal, start_indices_literal,
                                result_shape));
      } break;
      case U32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<uint32>(operand_literal, start_indices_literal,
                                 result_shape));
      } break;
      case U64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_slice],
            DynamicSlice<uint64>(operand_literal, start_indices_literal,
                                 result_shape));
      } break;
      default:
        LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
                      "start_indices: "
                   << PrimitiveType_Name(start_indices->shape().element_type());
    }

    return Status::OK();
  }

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override {
    auto operand = dynamic_update_slice->operand(0);
    auto update = dynamic_update_slice->operand(1);
    auto start_indices = dynamic_update_slice->operand(2);
    auto result_shape = dynamic_update_slice->shape();
    TF_ASSIGN_OR_RETURN(
        auto inferred_return_shape,
        ShapeInference::InferDynamicUpdateSliceShape(
            operand->shape(), update->shape(), start_indices->shape()));
    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);
    TF_RET_CHECK(
        primitive_util::IsIntegralType(start_indices->shape().element_type()));
    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
    const Literal& start_indices_literal =
        parent_->GetEvaluatedLiteralFor(start_indices);

    switch (start_indices->shape().element_type()) {
      case S32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<int32>(operand_literal, update_literal,
                                      start_indices_literal));
      } break;
      case S64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<int64>(operand_literal, update_literal,
                                      start_indices_literal));
      } break;
      case U32: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<uint32>(operand_literal, update_literal,
                                       start_indices_literal));
      } break;
      case U64: {
        TF_ASSIGN_OR_RETURN(
            parent_->evaluated_[dynamic_update_slice],
            DynamicUpdateSlice<uint64>(operand_literal, update_literal,
                                       start_indices_literal));
      } break;
      default:
        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
                      "start_indices: "
                   << PrimitiveType_Name(start_indices->shape().element_type());
    }

    return Status::OK();
  }

  template <typename NativeT>
  StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
    auto operands = map->operands();
    HloComputation* computation = map->to_apply();

    auto result = Literal::CreateFromShape(map->shape());

    HloEvaluator embedded_evaluator;
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          std::vector<std::unique_ptr<Literal>> arg_literals;
          arg_literals.reserve(operands.size());

          // Construct scalar literal parameters to be passed to the map
          // computation.
          for (auto operand : operands) {
            const Literal& arg_literal =
                parent_->GetEvaluatedLiteralFor(operand);

            auto curr_val = arg_literal.Get<NativeT>(multi_index);
            auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);

            arg_literals.push_back(std::move(curr_val_literal));
          }

          std::unique_ptr<Literal> computed_result =
              embedded_evaluator
                  .Evaluate<std::unique_ptr<Literal>>(*computation,
                                                      arg_literals)
                  .ConsumeValueOrDie();
          // Clear visit states so that we can reuse the evaluator on the
          // same computation.
          embedded_evaluator.ResetVisitStates();

          return computed_result->Get<ReturnT>({});
        }));
    return std::move(result);
  }

  Status HandleMap(HloInstruction* map) override {
    switch (map->operand(0)->shape().element_type()) {
      case PRED: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map));
        break;
      }
      case U8: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
        break;
      }
      case U32: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
        break;
      }
      case U64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map));
        break;
      }
      case S8: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
        break;
      }
      case S32: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
        break;
      }
      case S64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
        break;
      }
      case F16: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
                            MapImpl<Eigen::half>(map));
        break;
      }
      case F32: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
        break;
      }
      case F64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map));
        break;
      }
      case C64: {
        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
        break;
      }
      default:
        LOG(FATAL) << "HandleMap: unhandled primitive type for "
                      "input operand: "
                   << PrimitiveType_Name(
                          map->operand(0)->shape().element_type());
    }

    return Status::OK();
  }

  Status HandleReduce(HloInstruction* reduce) override {
    auto arg = reduce->operand(0);
    auto init_value = reduce->operand(1);
    tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
    HloComputation* function = reduce->to_apply();
    TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
                 ShapeUtil::Rank(arg->shape()) - dimensions.size());
    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
                        ShapeInference::InferReduceShape(
                            /*arg=*/arg->shape(),
                            /*init_value=*/init_value->shape(),
                            /*dimensions_to_reduce=*/dimensions,
                            /*to_apply=*/function->ComputeProgramShape()));
    TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
        << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
        << " but is inferred to be: "
        << ShapeUtil::HumanString(inferred_return_shape);

    const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
    VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
    const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
    VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
    auto init_scalar = init_literal.Get<ReturnT>({});

    auto result = Literal::CreateFromShape(reduce->shape());

    const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
    std::vector<int64> arg_dim_steps(arg_dimensions.size());
    std::vector<int64> arg_dim_counts(arg_dimensions.size());
    for (const int64 dim : dimensions) {
      arg_dim_steps[dim] = 1;
      arg_dim_counts[dim] = arg_dimensions[dim];
    }

    // Create the mapping from result index to arg index.
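    // For example, reducing an arg of shape [2, 3, 4] over dimensions {1}
    // yields a rank-2 result whose dimensions 0 and 1 map back to arg
    // dimensions 0 and 2 respectively.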
    const int64 result_rank = ShapeUtil::Rank(result->shape());
    int64 result_dim = 0;
    std::vector<int64> result_to_arg_index(result_rank);
    for (int64 i = 0; i < arg_dimensions.size(); ++i) {
      if (arg_dim_steps[i] == 0) {
        result_to_arg_index[result_dim] = i;
        ++result_dim;
      }
    }

    HloEvaluator embedded_evaluator;
    // For each resulting dimension, calculate and assign the computed value.
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          ReturnT result_val = init_scalar;

          std::vector<int64> base(arg_dimensions.size());
          for (int64 i = 0; i < multi_index.size(); ++i) {
            base[result_to_arg_index[i]] = multi_index[i];
          }

          auto func = [&](const std::vector<int64>& input_index) {
            auto curr_val = arg_literal.Get<ReturnT>(input_index);

            // Evaluate the computation with the specified literal operands.
            auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
            auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
            std::vector<const Literal*> args = {curr_val_literal.get(),
                                                result_val_literal.get()};

            std::unique_ptr<Literal> computed_result =
                embedded_evaluator.Evaluate<const Literal*>(*function, args)
                    .ConsumeValueOrDie();
            // Clear visit states so that we can reuse the evaluator on the
            // same computation.
            embedded_evaluator.ResetVisitStates();

            // Assign the computed result to result_val.
            result_val = computed_result->Get<ReturnT>({});

            return true;
          };

          ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
                                  arg_dim_steps, func);

          return result_val;
        }));

    parent_->evaluated_[reduce] = std::move(result);
    return Status::OK();
  }

  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
    auto operand = select_and_scatter->operand(0);
    auto source = select_and_scatter->operand(1);
    const Window& window = select_and_scatter->window();

    const Literal& init_literal =
        parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
    auto init_scalar = init_literal.Get<ReturnT>({});

    auto result = Literal::CreateFromShape(select_and_scatter->shape());

    // Initialize the result array with the init value.
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> output_index) {
          return init_scalar;
        }));

    std::vector<int64> window_dimension_sizes;
    for (const auto& window_dimension : window.dimensions()) {
      window_dimension_sizes.push_back(window_dimension.size());
    }
    const Shape window_shape = ShapeUtil::MakeShape(
        operand->shape().element_type(), window_dimension_sizes);

    HloComputation* select = select_and_scatter->select();
    HloComputation* scatter = select_and_scatter->scatter();

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
    const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);

    int64 rank = ShapeUtil::Rank(operand_literal.shape());

    HloEvaluator embedded_evaluator;
    DimensionVector source_index(rank);

    std::fill(source_index.begin(), source_index.end(), 0);
    do {
      // For each element in `source`, we place a window in `operand`.
      // For each window placement, we iterate inside the window twice:
      //
      // 1. Find the selected index by applying the `select` function to all
      //    elements. E.g., if the `select` function is GreaterEqual, the first
      //    iteration through the window finds the biggest value and returns
      //    its index.
      //
      // 2. Using the selected index, scatter the value from `source` to the
      //    result. We do this by iterating through the window and comparing
      //    each index with the selected index.
      tensorflow::gtl::optional<ReturnT> selected_val;
      tensorflow::gtl::optional<std::vector<int64>> selected_index;

      IterateThroughWindow(
          window_shape, window, operand_literal.shape(), source_index,
          [&](const std::vector<int64>& operand_index) {
            auto curr_val = operand_literal.Get<ReturnT>(operand_index);
            if (!selected_val) {
              selected_val = curr_val;
              selected_index = operand_index;
            }
            const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
            const auto selected_val_literal =
                Literal::CreateR0<ReturnT>(*selected_val);

            const std::vector<const Literal*> args = {
                curr_val_literal.get(), selected_val_literal.get()};
            std::unique_ptr<Literal> computed_result =
                embedded_evaluator.Evaluate<const Literal*>(*select, args)
                    .ConsumeValueOrDie();
            bool selected = computed_result->Get<bool>({});
            if (selected) {
              selected_val = curr_val;
              selected_index = operand_index;
            }
            embedded_evaluator.ResetVisitStates();
          });

      IterateThroughWindow(
          window_shape, window, operand_literal.shape(), source_index,
          [&](const std::vector<int64>& operand_index) {
            if (std::equal(operand_index.begin(), operand_index.end(),
                           selected_index->begin())) {
              auto source = source_literal.Get<ReturnT>(source_index);
              auto scattered = result->Get<ReturnT>(operand_index);
              const auto source_literal = Literal::CreateR0<ReturnT>(source);
              const auto scattered_literal =
                  Literal::CreateR0<ReturnT>(scattered);

              const std::vector<const Literal*> args = {
                  source_literal.get(), scattered_literal.get()};
              std::unique_ptr<Literal> computed_result =
                  embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
                      .ConsumeValueOrDie();
              result->Set(operand_index, computed_result->Get<ReturnT>({}));
              // Clear visit states so that we can reuse the evaluator on the
              // same computation.

  Status HandleReduceWindow(HloInstruction* reduce_window) override {
    auto operand = reduce_window->operand(0);
    const Window& window = reduce_window->window();
    HloComputation* function = reduce_window->to_apply();
    TF_ASSIGN_OR_RETURN(
        auto inferred_return_shape,
        ShapeInference::InferReduceWindowShape(
            /*operand_shape=*/reduce_window->operand(0)->shape(),
            /*init_value=*/reduce_window->operand(1)->shape(), window,
            /*to_apply_shape=*/function->ComputeProgramShape()));
    TF_RET_CHECK(
        ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
        << "return shape is set to: "
        << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
        << " but is inferred to be: "
        << ShapeUtil::HumanStringWithLayout(inferred_return_shape);

    const Literal& operand_literal =
        parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
    VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
    const Literal& init_literal =
        parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
    VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
    auto init_scalar = init_literal.Get<ReturnT>({});

    auto result = Literal::CreateFromShape(reduce_window->shape());

    // Creates a Shape object from window, for iteration below.
    std::vector<int64> window_dimension_sizes;
    for (const auto& window_dimension : window.dimensions()) {
      window_dimension_sizes.push_back(window_dimension.size());
    }
    const Shape window_shape = ShapeUtil::MakeShape(
        operand->shape().element_type(), window_dimension_sizes);

    DimensionVector window_index(window.dimensions_size());
    DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));

    HloEvaluator embedded_evaluator;
    // For each resulting dimension, calculate and assign computed value.
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> output_index) {
          ReturnT result_val = init_scalar;

          std::fill(window_index.begin(), window_index.end(), 0);
          std::fill(operand_index.begin(), operand_index.end(), 0);

          IterateThroughWindow(
              window_shape, window, operand_literal.shape(), output_index,
              [&](const std::vector<int64>& operand_index) {
                auto curr_val = operand_literal.Get<ReturnT>(operand_index);

                // Evaluate computation with specified literal operands.
                const auto curr_val_literal =
                    Literal::CreateR0<ReturnT>(curr_val);
                const auto result_val_literal =
                    Literal::CreateR0<ReturnT>(result_val);
                const std::vector<const Literal*> args = {
                    curr_val_literal.get(), result_val_literal.get()};
                std::unique_ptr<Literal> computed_result =
                    embedded_evaluator
                        .Evaluate<const Literal*>(*function, args)
                        .ConsumeValueOrDie();

                // Clear visit states so that we can use the evaluator again
                // on the same computation.
1695 embedded_evaluator.ResetVisitStates(); 1696 1697 result_val = computed_result->Get<ReturnT>({}); 1698 }); 1699 1700 return result_val; 1701 })); 1702 1703 parent_->evaluated_[reduce_window] = std::move(result); 1704 return Status::OK(); 1705 } 1706 1707 Status HandleSlice(HloInstruction* slice) override { 1708 auto operand = slice->operand(0); 1709 const Shape& shape = slice->shape(); 1710 TF_ASSIGN_OR_RETURN(auto inferred_return_shape, 1711 ShapeInference::InferSliceShape( 1712 operand->shape(), slice->slice_starts(), 1713 slice->slice_limits(), slice->slice_strides())); 1714 TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) 1715 << "return shape set to: " << ShapeUtil::HumanString(shape) 1716 << " but is inferred to be: " 1717 << ShapeUtil::HumanString(inferred_return_shape); 1718 1719 const int64 rank = ShapeUtil::Rank(operand->shape()); 1720 const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); 1721 auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) { 1722 DimensionVector operand_index(rank); 1723 for (int64 i = 0; i < rank; ++i) { 1724 operand_index[i] = 1725 slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); 1726 } 1727 return operand_literal.Get<ReturnT>(operand_index); 1728 }; 1729 1730 auto result = Literal::CreateFromDimensions( 1731 shape.element_type(), AsInt64Slice(shape.dimensions())); 1732 TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func)); 1733 parent_->evaluated_[slice] = std::move(result); 1734 return Status::OK(); 1735 } 1736 1737 template <typename NativeT, typename std::enable_if<std::is_floating_point< 1738 NativeT>::value>::type* = nullptr> 1739 Status HandleSin(HloInstruction* sin) { 1740 TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], 1741 ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { 1742 return std::sin(elem_operand); 1743 })); 1744 return Status::OK(); 1745 } 1746 1747 template < 1748 typename NativeT, 1749 typename std::enable_if<std::is_integral<NativeT>::value || 1750 is_complex_t<NativeT>::value>::type* = nullptr> 1751 Status HandleSin(HloInstruction* sin) { 1752 return InvalidArgument("Unsupported type for Sin"); 1753 } 1754 1755 Status HandleSin(HloInstruction* sin) override { 1756 return HandleSin<ElementwiseT>(sin); 1757 } 1758 1759 template <typename NativeT, typename std::enable_if<std::is_floating_point< 1760 NativeT>::value>::type* = nullptr> 1761 Status HandleCos(HloInstruction* cos) { 1762 TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], 1763 ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { 1764 return std::cos(elem_operand); 1765 })); 1766 return Status::OK(); 1767 } 1768 1769 template < 1770 typename NativeT, 1771 typename std::enable_if<std::is_integral<NativeT>::value || 1772 is_complex_t<NativeT>::value>::type* = nullptr> 1773 Status HandleCos(HloInstruction* cos) { 1774 return InvalidArgument("Unsupported type for Cos"); 1775 } 1776 1777 Status HandleCos(HloInstruction* cos) override { 1778 return HandleCos<ElementwiseT>(cos); 1779 } 1780 1781 template <typename NativeT, typename std::enable_if<std::is_same< 1782 float, NativeT>::value>::type* = nullptr> 1783 Status HandleReducePrecision(HloInstruction* reduce_precision) { 1784 TF_ASSIGN_OR_RETURN( 1785 parent_->evaluated_[reduce_precision], 1786 ElementWiseUnaryOp(reduce_precision, [reduce_precision]( 1787 ElementwiseT elem) { 1788 uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem); 1789 const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); 1790 const uint32_t exponent_bits = 
              reduce_precision->exponent_bits();

          // Code is based on the CPU/GPU implementation in LLVM-emitting code.
          //
          // Bits in float type:
          //   mantissa : bits [0:22]
          //   exponent : bits [23:30]
          //   sign : bits [31]
          if (mantissa_bits < 23) {
            const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);

            // Compute rounding bias for round-to-nearest with ties to even.
            // This is equal to a base value of 0111... plus one bit if the
            // last remaining mantissa bit is 1.
            const uint32_t base_rounding_bias =
                (last_mantissa_bit_mask >> 1) - 1;
            const uint32_t x_last_mantissa_bit =
                (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
            const uint32_t x_rounding_bias =
                x_last_mantissa_bit + base_rounding_bias;

            // Add rounding bias, and mask out truncated bits. Note that the
            // case where adding the rounding bias overflows into the exponent
            // bits is correct; the non-masked mantissa bits will all be zero,
            // and the exponent will be incremented by one.
            const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
            value_as_int = value_as_int + x_rounding_bias;
            value_as_int = value_as_int & truncation_mask;
          }
          if (exponent_bits < 8) {
            // Masks for f32 values.
            const uint32_t f32_sign_bit_mask = 1u << 31;
            const uint32_t f32_exp_bits_mask = 0xffu << 23;

            // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in
            // the most-significant bit -- is equal to 1.0f for all exponent
            // sizes. Adding 2^(n-1)-1 to this gives us the highest
            // non-infinite exponent for a bit-size of n, and subtracting
            // 2^(n-1)-1 from this gives us the lowest exponent (corresponding
            // to 0.0f).
            //
            // Thus, the f32 exponent corresponding to the highest non-infinite
            // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
            // exponent corresponding to the lowest exponent for a bit size of
            // n is (2^7-1) - 2^(n-1)-1.
            //
            // Note that we have already checked that exponent_bits >= 1.
            const uint32_t f32_exponent_bias = (1 << 7) - 1;
            const uint32_t reduced_exponent_bias =
                (1 << (exponent_bits - 1)) - 1;
            const uint32_t reduced_max_exponent =
                f32_exponent_bias + reduced_exponent_bias;
            const uint32_t reduced_min_exponent =
                f32_exponent_bias - reduced_exponent_bias;

            // Do we overflow or underflow?
            const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
            const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
            const bool x_underflows =
                x_exponent <= (reduced_min_exponent << 23);

            // Compute appropriately-signed values of zero and infinity.
            const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
            const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;

            // Force to zero or infinity if overflow or underflow. (Note that
            // this truncates all denormal values to zero, rather than rounding
            // them.)
            value_as_int = x_overflows ? x_signed_inf : value_as_int;
            value_as_int = x_underflows ? x_signed_zero : value_as_int;
          }

          float reduced_result = tensorflow::bit_cast<float>(value_as_int);
          if (std::isnan(elem)) {
            reduced_result = mantissa_bits > 0
                                 ? elem
                                 : std::numeric_limits<float>::infinity();
          }
          return reduced_result;
        }));
    return Status::OK();
  }
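
  // Worked example of the mantissa-rounding path above (illustrative only;
  // the concrete values are assumptions): with mantissa_bits = 7 (a
  // bfloat16-like mantissa) and input 1.00390625f, whose bit pattern is
  // 0x3f808000:
  //
  //   last_mantissa_bit_mask = 1u << 16                  == 0x10000
  //   base_rounding_bias     = 0x8000 - 1                == 0x7fff
  //   x_last_mantissa_bit    = (0x3f808000 & 0x10000) >> 16 == 0
  //   x_rounding_bias        = 0 + 0x7fff                == 0x7fff
  //   0x3f808000 + 0x7fff                                == 0x3f80ffff
  //   0x3f80ffff & ~0xffffu                              == 0x3f800000 (1.0f)
  //
  // The input sits exactly halfway between the two nearest representable
  // values (1.0f and 1.0078125f) and rounds to the one with the even
  // mantissa, as intended by the ties-to-even bias.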

  template <typename NativeT, typename std::enable_if<std::is_same<
                                  double, NativeT>::value>::type* = nullptr>
  Status HandleReducePrecision(HloInstruction* reduce_precision) {
    return InvalidArgument("Double not supported for reduce precision");
  }

  template <
      typename NativeT,
      typename std::enable_if<std::is_integral<NativeT>::value ||
                              is_complex_t<NativeT>::value>::type* = nullptr>
  Status HandleReducePrecision(HloInstruction* reduce_precision) {
    return InvalidArgument("Unsupported type for reduce precision");
  }

  Status HandleReducePrecision(HloInstruction* reduce_precision) override {
    return HandleReducePrecision<ElementwiseT>(reduce_precision);
  }

 private:
  template <typename IndexT>
  StatusOr<std::unique_ptr<Literal>> DynamicSlice(
      const Literal& operand_literal, const Literal& start_indices_literal,
      const Shape& result_shape) {
    auto start_indices_typed = start_indices_literal.data<IndexT>();
    std::vector<int64> start(start_indices_typed.begin(),
                             start_indices_typed.end());

    std::vector<int64> operand_indices(start.size());

    auto result = Literal::CreateFromShape(result_shape);
    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          for (int64 i = 0; i < operand_indices.size(); ++i) {
            CHECK_GE(multi_index[i] + start[i], 0);
            // Mod is only used here to be consistent with the existing
            // backends' behavior.
            operand_indices[i] = (multi_index[i] + start[i]) %
                                 operand_literal.shape().dimensions(i);
          }

          auto result = operand_literal.Get<ReturnT>(operand_indices);
          return result;
        }));

    return std::move(result);
  }

  template <typename IndexT>
  StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
      const Literal& operand_literal, const Literal& update_literal,
      const Literal& start_indices_literal) {
    auto start_indices_typed = start_indices_literal.data<IndexT>();
    const std::vector<int64> start(start_indices_typed.begin(),
                                   start_indices_typed.end());

    auto result = operand_literal.CloneToUnique();
    std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);

    auto func = [&](const std::vector<int64>& update_index) {
      std::transform(update_index.begin(), update_index.end(), start.begin(),
                     result_index.begin(), std::plus<int64>());

      result->Set<ReturnT>(result_index,
                           update_literal.Get<ReturnT>(update_index));
      return true;
    };

    std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
    std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
    ShapeUtil::ForEachIndex(update_literal.shape(), base,
                            AsInt64Slice(update_literal.shape().dimensions()),
                            step, func);

    return std::move(result);
  }
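
  // Illustrative note on the modulo above (not part of the evaluator): for a
  // 1-D operand of length 4, start index 3 and result length 2, DynamicSlice
  // reads operand indices (3 + 0) % 4 == 3 and (3 + 1) % 4 == 0, i.e.
  // out-of-range reads wrap around instead of failing, matching the existing
  // backends' behavior.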

  StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
      HloInstruction* instruction,
      const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
    const Literal& operand_literal =
        parent_->GetEvaluatedLiteralFor(instruction->operand(0));
    TF_ASSIGN_OR_RETURN(
        auto result_literal,
        (ElementWiseUnaryOpImpl<ReturnT, ReturnT>(
            instruction, ConvertUnaryFunction(unary_op), operand_literal)));

    return std::move(result_literal);
  }

  StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
      HloInstruction* instruction,
      const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
          binary_op) {
    const auto shape = instruction->shape();
    const auto* lhs = instruction->operand(0);
    const auto* rhs = instruction->operand(1);

    // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast
    // is removed.
    if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) &&
          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
      return Unimplemented(
          "Implicit broadcasting is currently unsupported in HLO evaluator "
          "Shape Mismatch: %s vs %s vs %s: ",
          ShapeUtil::HumanString(shape).c_str(),
          ShapeUtil::HumanString(lhs->shape()).c_str(),
          ShapeUtil::HumanString(rhs->shape()).c_str());
    }

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    auto result = Literal::CreateFromShape(shape);

    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          return ConvertBinaryFunction(binary_op)(
              lhs_literal.Get<ReturnT>(multi_index),
              rhs_literal.Get<ReturnT>(multi_index));
        }));
    return std::move(result);
  }

  template <typename LhsType, typename RhsType, typename EhsType>
  StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
      HloInstruction* instruction,
      const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
    const auto shape = instruction->shape();
    const auto* lhs = instruction->operand(0);
    const auto* rhs = instruction->operand(1);
    const auto* ehs = instruction->operand(2);

    // TODO(b/35950897, b/27796129): add DCHECK back once implicit
    // broadcast is removed.
2006 if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && 2007 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && 2008 ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { 2009 return Unimplemented( 2010 "Implicit broadcasting is currently unsupported in HLO evaluator " 2011 "Shape Mismatch: %s vs %s vs %s vs %s: ", 2012 ShapeUtil::HumanString(shape).c_str(), 2013 ShapeUtil::HumanString(lhs->shape()).c_str(), 2014 ShapeUtil::HumanString(rhs->shape()).c_str(), 2015 ShapeUtil::HumanString(ehs->shape()).c_str()); 2016 } 2017 2018 const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); 2019 const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); 2020 const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); 2021 2022 auto result = Literal::CreateFromShape(shape); 2023 2024 TF_RETURN_IF_ERROR(result->Populate<ReturnT>( 2025 [&](tensorflow::gtl::ArraySlice<int64> multi_index) { 2026 return ternary_op(lhs_literal.Get<LhsType>(multi_index), 2027 rhs_literal.Get<RhsType>(multi_index), 2028 ehs_literal.Get<EhsType>(multi_index)); 2029 })); 2030 2031 return std::move(result); 2032 } 2033 2034 HloEvaluator* parent_; 2035 }; // class HloEvaluator::TypedVisitor 2036 2037 HloEvaluator::HloEvaluator() { 2038 typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this); 2039 typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this); 2040 typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) { 2041 return Unimplemented("HloEvaluator: unhandled primitive type: U16."); 2042 }); 2043 typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this); 2044 typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this); 2045 typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this); 2046 typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) { 2047 return Unimplemented("HloEvaluator: unhandled primitive type: S16."); 2048 }); 2049 typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this); 2050 typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this); 2051 typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this); 2052 typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this); 2053 typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this); 2054 typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this); 2055 2056 // Most of the evaluator computations we use don't support BF16 (e.g., 2057 // std::ceil, std::tanh). To make evaluator work with BF16, we set all 2058 // elementwise computations to be done in F32 and do BF16<->F32 conversion 2059 // around the input and the output of the computations. 
2060 typed_visitors_[BF16] = MakeUnique<TypedVisitor<bfloat16, float>>(this); 2061 typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { 2062 return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); 2063 }); 2064 typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { 2065 return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE."); 2066 }); 2067 } 2068 2069 template <typename LiteralPtr> 2070 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( 2071 const HloModule& module, 2072 tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) { 2073 XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); 2074 2075 evaluated_.clear(); 2076 arg_literals_.clear(); 2077 for (const auto& literal_ptr : arg_literals) { 2078 arg_literals_.push_back(&*literal_ptr); 2079 } 2080 2081 TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); 2082 2083 return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) 2084 .CloneToUnique(); 2085 } 2086 2087 template <typename LiteralPtr> 2088 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( 2089 const HloComputation& computation, 2090 tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) { 2091 XLA_VLOG_LINES( 2092 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); 2093 2094 evaluated_.clear(); 2095 arg_literals_.clear(); 2096 for (const auto& literal_ptr : arg_literals) { 2097 arg_literals_.push_back(&*literal_ptr); 2098 } 2099 2100 TF_RETURN_IF_ERROR(computation.Accept(this)); 2101 return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); 2102 } 2103 2104 template <typename LiteralPtr> 2105 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( 2106 HloInstruction* instruction, 2107 tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals) { 2108 TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); 2109 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); 2110 2111 evaluated_.clear(); 2112 arg_literals_.clear(); 2113 for (const auto& literal_ptr : arg_literals) { 2114 arg_literals_.push_back(&*literal_ptr); 2115 } 2116 2117 // Evaluate operands of Parameter type against the input literals which 2118 // caches the evaluated literal results. 
2119 for (const auto operand : instruction->operands()) { 2120 if (operand->opcode() == HloOpcode::kParameter) { 2121 const Literal* input_literal = arg_literals_[operand->parameter_number()]; 2122 VLOG(2) << "Parameter operand evaluated to: " 2123 << input_literal->ToString(); 2124 TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); 2125 2126 evaluated_[operand] = input_literal->CloneToUnique(); 2127 } 2128 } 2129 2130 TF_RETURN_IF_ERROR(Preprocess(instruction)); 2131 TF_RETURN_IF_ERROR(instruction->Visit(this)); 2132 TF_RETURN_IF_ERROR(Postprocess(instruction)); 2133 return GetEvaluatedLiteralFor(instruction).CloneToUnique(); 2134 } 2135 2136 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( 2137 HloInstruction* instruction) { 2138 if (instruction->opcode() == HloOpcode::kParameter) { 2139 return tensorflow::errors::FailedPrecondition( 2140 "Cannot evaluate a parameter."); 2141 } 2142 if (!hlo_query::AllOperandsAreConstants(*instruction)) { 2143 return tensorflow::errors::FailedPrecondition( 2144 "Not all operands are constants."); 2145 } 2146 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); 2147 2148 arg_literals_.clear(); 2149 evaluated_.clear(); 2150 2151 TF_RETURN_IF_ERROR(Preprocess(instruction)); 2152 TF_RETURN_IF_ERROR(instruction->Visit(this)); 2153 TF_RETURN_IF_ERROR(Postprocess(instruction)); 2154 return GetEvaluatedLiteralFor(instruction).CloneToUnique(); 2155 } 2156 2157 std::unique_ptr<Literal> HloEvaluator::TryEvaluate( 2158 HloInstruction* instruction) { 2159 auto result_or = Evaluate(instruction); 2160 if (!result_or.ok()) { 2161 VLOG(1) << "TryEvaluate failed:" << result_or.status(); 2162 return nullptr; 2163 } 2164 2165 return result_or.ConsumeValueOrDie(); 2166 } 2167 2168 StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions( 2169 const HloInstruction* instruction, 2170 const std::unordered_map<const HloInstruction*, const Literal*>& 2171 substitutions) { 2172 std::vector<std::unique_ptr<HloInstruction>> owned_operands; 2173 for (const HloInstruction* operand : instruction->operands()) { 2174 auto it = substitutions.find(operand); 2175 if (it == substitutions.end()) { 2176 owned_operands.push_back(operand->Clone()); 2177 } else { 2178 owned_operands.push_back( 2179 HloInstruction::CreateConstant(it->second->CloneToUnique())); 2180 } 2181 } 2182 2183 std::vector<HloInstruction*> operands; 2184 operands.reserve(owned_operands.size()); 2185 for (auto& operand : owned_operands) { 2186 operands.push_back(operand.get()); 2187 } 2188 2189 std::unique_ptr<HloInstruction> cloned_instruction = 2190 instruction->CloneWithNewOperands(instruction->shape(), operands); 2191 auto result = Evaluate(cloned_instruction.get()); 2192 2193 // Clean up our cloned instructions before returning. 
  cloned_instruction->DetachFromOperands();
  for (auto& operand : owned_operands) {
    operand->DetachFromOperands();
  }

  return result;
}

Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
  CHECK_LT(parameter->parameter_number(), arg_literals_.size());
  const Literal* input_literal = arg_literals_[parameter->parameter_number()];
  VLOG(2) << "Parameter evaluated to: " << input_literal->ToString();
  DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape()))
      << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape())
      << ", but input literal shape is: "
      << ShapeUtil::HumanString(input_literal->shape());

  evaluated_[parameter] = input_literal->CloneToUnique();
  return Status::OK();
}

Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); }

Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
  TF_ASSIGN_OR_RETURN(
      evaluated_[reshape],
      GetEvaluatedLiteralFor(reshape->operand(0))
          .Reshape(AsInt64Slice(reshape->shape().dimensions())));
  return Status::OK();
}

Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
  evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0))
                              .Transpose(transpose->dimensions());
  return Status::OK();
}

Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
  tensorflow::gtl::ArraySlice<HloInstruction*> operands(
      concatenate->operands());
  // The result concatenate dimension is going to be the sum of all
  // concatenate dimensions of the operands taking part in the operation.
  const Shape& reference_shape = operands[0]->shape();
  CHECK(!ShapeUtil::IsTuple(reference_shape));
  const int64 rank = ShapeUtil::Rank(reference_shape);
  const int64 concat_dim = concatenate->dimensions()[0];
  CHECK_GE(concat_dim, 0);
  CHECK_LT(concat_dim, rank);

  DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
                                    reference_shape.dimensions().end());

  for (int64 i = 1; i < operands.size(); ++i) {
    const Shape& operand_shape = operands[i]->shape();
    CHECK(!ShapeUtil::IsTuple(operand_shape));
    // Accumulate the concat dimension from all tensors taking part in the
    // operation.
2251 concat_dimensions[concat_dim] += 2252 ShapeUtil::GetDimension(operand_shape, concat_dim); 2253 } 2254 2255 auto result_literal = Literal::CreateFromDimensions( 2256 reference_shape.element_type(), concat_dimensions); 2257 DimensionVector source_indices(rank, 0); 2258 DimensionVector dest_indices(concat_dimensions.size(), 0); 2259 2260 for (auto operand : operands) { 2261 const Shape& operand_shape = operand->shape(); 2262 TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( 2263 GetEvaluatedLiteralFor(operand), source_indices, dest_indices, 2264 AsInt64Slice(operand_shape.dimensions()))); 2265 dest_indices[concat_dim] += 2266 ShapeUtil::GetDimension(operand_shape, concat_dim); 2267 } 2268 2269 evaluated_[concatenate] = std::move(result_literal); 2270 return Status::OK(); 2271 } 2272 2273 Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { 2274 auto operand = is_finite->operand(0); 2275 if (!ShapeUtil::ElementIsFloating(operand->shape())) { 2276 return InvalidArgument( 2277 "expected element type in shape to be float for IsFinite op, got: %s", 2278 PrimitiveType_Name(operand->shape().element_type()).c_str()); 2279 } 2280 2281 switch (operand->shape().element_type()) { 2282 case F16: 2283 return Unimplemented("unhandled primitive type: F16."); 2284 case F32: { 2285 auto result_or = ElementWiseUnaryOpImpl<bool, float>( 2286 is_finite, 2287 [](float elem_operand) { return std::isfinite(elem_operand); }, 2288 GetEvaluatedLiteralFor(operand)); 2289 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); 2290 break; 2291 } 2292 case F64: { 2293 auto result_or = ElementWiseUnaryOpImpl<bool, double>( 2294 is_finite, 2295 [](double elem_operand) { return std::isfinite(elem_operand); }, 2296 GetEvaluatedLiteralFor(operand)); 2297 TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); 2298 break; 2299 } 2300 default: 2301 LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: " 2302 << PrimitiveType_Name(operand->shape().element_type()); 2303 } 2304 2305 return Status::OK(); 2306 } 2307 2308 Status HloEvaluator::HandleCompare(HloInstruction* compare) { 2309 HloOpcode opcode = compare->opcode(); 2310 auto lhs = compare->operand(0); 2311 auto rhs = compare->operand(1); 2312 // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is 2313 // removed. 2314 if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && 2315 ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { 2316 return Unimplemented( 2317 "Implicit broadcasting is currently unsupported in HLO evaluator " 2318 "Shape Mismatch: %s vs %s vs %s", 2319 ShapeUtil::HumanString(compare->shape()).c_str(), 2320 ShapeUtil::HumanString(lhs->shape()).c_str(), 2321 ShapeUtil::HumanString(rhs->shape()).c_str()); 2322 } 2323 2324 TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); 2325 2326 const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs); 2327 const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs); 2328 2329 // Note here we switch on the operand's type. 
2330 switch (lhs->shape().element_type()) { 2331 case PRED: { 2332 TF_ASSIGN_OR_RETURN( 2333 evaluated_[compare], 2334 Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2335 } break; 2336 case U8: { 2337 TF_ASSIGN_OR_RETURN( 2338 evaluated_[compare], 2339 Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2340 } break; 2341 case U16: 2342 return Unimplemented("unhandled primitive type: U16."); 2343 case U32: { 2344 TF_ASSIGN_OR_RETURN( 2345 evaluated_[compare], 2346 Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2347 } break; 2348 case U64: { 2349 TF_ASSIGN_OR_RETURN( 2350 evaluated_[compare], 2351 Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2352 } break; 2353 case S8: { 2354 TF_ASSIGN_OR_RETURN( 2355 evaluated_[compare], 2356 Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2357 } break; 2358 case S16: 2359 return Unimplemented("unhandled primitive type: S16."); 2360 case S32: { 2361 TF_ASSIGN_OR_RETURN( 2362 evaluated_[compare], 2363 Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2364 } break; 2365 case S64: { 2366 TF_ASSIGN_OR_RETURN( 2367 evaluated_[compare], 2368 Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2369 } break; 2370 case F16: 2371 return Unimplemented("unhandled primitive type: F16."); 2372 case F32: { 2373 TF_ASSIGN_OR_RETURN( 2374 evaluated_[compare], 2375 Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2376 } break; 2377 case F64: { 2378 TF_ASSIGN_OR_RETURN( 2379 evaluated_[compare], 2380 Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal)); 2381 } break; 2382 case C64: { 2383 TF_ASSIGN_OR_RETURN(evaluated_[compare], 2384 Compare<complex64>(compare->shape(), opcode, 2385 lhs_literal, rhs_literal)); 2386 } break; 2387 default: 2388 LOG(FATAL) << "HandleCompare: unknown primitive type: " 2389 << PrimitiveType_Name(lhs->shape().element_type()); 2390 } 2391 2392 return Status::OK(); 2393 } 2394 2395 Status HloEvaluator::HandleTuple(HloInstruction* tuple) { 2396 std::vector<const Literal*> operand_literals; 2397 for (auto operand : tuple->operands()) { 2398 operand_literals.push_back(&GetEvaluatedLiteralFor(operand)); 2399 } 2400 2401 evaluated_[tuple] = Literal::MakeTuple(operand_literals); 2402 return Status::OK(); 2403 } 2404 2405 Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { 2406 const auto result_shape = get_tuple_element->shape(); 2407 const int64 index = get_tuple_element->tuple_index(); 2408 2409 auto operand = get_tuple_element->operand(0); 2410 TF_ASSIGN_OR_RETURN( 2411 auto inferred_return_shape, 2412 ShapeInference::InferGetTupleElementShape(operand->shape(), index)); 2413 TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) 2414 << "return shape set to: " << ShapeUtil::HumanString(result_shape) 2415 << " but is inferred to be: " 2416 << ShapeUtil::HumanString(inferred_return_shape); 2417 2418 const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); 2419 2420 evaluated_[get_tuple_element] = MakeUnique<Literal>( 2421 ShapeUtil::GetTupleElementShape(operand->shape(), index)); 2422 return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, 2423 /*dest_shape_index=*/{}, 2424 /*src_shape_index=*/{index}); 2425 } 2426 2427 Status HloEvaluator::HandleCopy(HloInstruction* copy) { 2428 TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); 2429 2430 auto result = 
GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); 2431 evaluated_[copy] = std::move(result); 2432 return Status::OK(); 2433 } 2434 2435 Status HloEvaluator::Preprocess(HloInstruction* hlo) { 2436 VLOG(2) << "About to visit HLO: " << hlo->ToString(); 2437 return Status::OK(); 2438 } 2439 2440 Status HloEvaluator::Postprocess(HloInstruction* hlo) { 2441 VLOG(2) << "Finished visiting " << hlo->ToString() 2442 << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); 2443 return Status::OK(); 2444 } 2445 2446 // Explicit instantiation of templatized Evaluate* methods. 2447 // 2448 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate< 2449 const Literal*>(const HloModule& module, 2450 tensorflow::gtl::ArraySlice<const Literal*> arg_literals); 2451 template StatusOr<std::unique_ptr<Literal>> 2452 HloEvaluator::Evaluate<std::unique_ptr<Literal>>( 2453 const HloModule& module, 2454 tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals); 2455 2456 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate< 2457 const Literal*>(const HloComputation& computation, 2458 tensorflow::gtl::ArraySlice<const Literal*> arg_literals); 2459 template StatusOr<std::unique_ptr<Literal>> 2460 HloEvaluator::Evaluate<std::unique_ptr<Literal>>( 2461 const HloComputation& computation, 2462 tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals); 2463 2464 template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate< 2465 const Literal*>(HloInstruction* instruction, 2466 tensorflow::gtl::ArraySlice<const Literal*> arg_literals); 2467 template StatusOr<std::unique_ptr<Literal>> 2468 HloEvaluator::Evaluate<std::unique_ptr<Literal>>( 2469 HloInstruction* instruction, 2470 tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arg_literals); 2471 2472 } // namespace xla 2473
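
// Usage sketch (illustrative only; `computation` and the argument literal are
// assumptions, not part of this file): evaluating an HloComputation whose
// single parameter is an R1 F32 value.
//
//   xla::HloEvaluator evaluator;
//   auto arg = xla::Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   std::vector<const xla::Literal*> args = {arg.get()};
//   xla::StatusOr<std::unique_ptr<xla::Literal>> result =
//       evaluator.Evaluate<const xla::Literal*>(*computation, args);
//   if (result.ok()) {
//     LOG(INFO) << result.ValueOrDie()->ToString();
//   }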