Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
     17 
     18 #include <algorithm>
     19 #include <memory>
     20 #include <string>
     21 #include <vector>
     22 
     23 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
     24 #include "llvm/IR/BasicBlock.h"
     25 #include "llvm/IR/Instructions.h"
     26 #include "llvm/IR/Intrinsics.h"
     27 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     28 #include "tensorflow/compiler/xla/primitive_util.h"
     29 #include "tensorflow/compiler/xla/service/hlo_module.h"
     30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     31 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
     32 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
     33 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     34 #include "tensorflow/compiler/xla/shape_util.h"
     35 #include "tensorflow/compiler/xla/status_macros.h"
     36 #include "tensorflow/compiler/xla/statusor.h"
     37 #include "tensorflow/compiler/xla/types.h"
     38 #include "tensorflow/compiler/xla/util.h"
     39 #include "tensorflow/compiler/xla/xla_data.pb.h"
     40 #include "tensorflow/core/lib/random/random.h"
     41 #include "tensorflow/core/lib/strings/strcat.h"
     42 #include "tensorflow/core/platform/logging.h"
     43 #include "tensorflow/core/platform/types.h"
     44 
     45 namespace xla {
     46 
     47 using llvm_ir::AsStringRef;
     48 using llvm_ir::IrArray;
     49 using llvm_ir::IrName;
     50 using llvm_ir::SetToFirstInsertPoint;
     51 using tensorflow::strings::StrCat;
     52 
     53 namespace {
     54 
     55 llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
     56                                       int64 mantissa_bits,
     57                                       llvm::IRBuilder<>* ir_builder) {
     58   // Integer and float types for casting and constant generation.
     59   llvm::Type* float_type = x->getType();
     60   llvm::IntegerType* int_type = ir_builder->getInt32Ty();
     61 
     62   // Cast the input value to an integer for bitwise manipulation.
     63   llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type);
     64 
     65   if (mantissa_bits < 23) {
     66     // Last remaining mantissa bit.
     67     const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
     68 
     69     // Compute rounding bias for round-to-nearest with ties to even.  This is
     70     // equal to a base value of 0111... plus one bit if the last remaining
     71     // mantissa bit is 1.
     72     const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
     73     llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr(
     74         ir_builder->CreateAnd(
     75             x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
     76         (23 - mantissa_bits));
     77     llvm::Value* x_rounding_bias = ir_builder->CreateAdd(
     78         x_last_mantissa_bit,
     79         llvm::ConstantInt::get(int_type, base_rounding_bias));
     80 
     81     // Add rounding bias, and mask out truncated bits.  Note that the case
     82     // where adding the rounding bias overflows into the exponent bits is
     83     // correct; the non-masked mantissa bits will all be zero, and the
     84     // exponent will be incremented by one.
     85     const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
     86     x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias);
     87     x_as_int = ir_builder->CreateAnd(
     88         x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
     89   }
     90 
     91   if (exponent_bits < 8) {
     92     // Masks for f32 values.
     93     const uint32_t f32_sign_bit_mask = 1u << 31;
     94     const uint32_t f32_exp_bits_mask = 0xffu << 23;
     95 
     96     // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
     97     // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
     98     // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
     99     // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
    100     // exponent (corresponding to 0.0f).
    101     //
    102     // Thus, the f32 exponent corresponding to the highest non-infinite
    103     // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    104     // exponent corresponding to the lowest exponent for a bit size of n is
    105     // (2^7-1) - 2^(n-1)-1.
    106     //
    107     // Note that we have already checked that exponents_bits >= 1.
    108     const uint32_t f32_exponent_bias = (1 << 7) - 1;
    109     const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
    110     const uint32_t reduced_max_exponent =
    111         f32_exponent_bias + reduced_exponent_bias;
    112     const uint32_t reduced_min_exponent =
    113         f32_exponent_bias - reduced_exponent_bias;
    114 
    115     // Do we overflow or underflow?
    116     llvm::Value* x_exponent = ir_builder->CreateAnd(
    117         x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
    118     llvm::Value* x_overflows = ir_builder->CreateICmpUGT(
    119         x_exponent,
    120         llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
    121     llvm::Value* x_underflows = ir_builder->CreateICmpULE(
    122         x_exponent,
    123         llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
    124 
    125     // Compute appropriately-signed values of zero and infinity.
    126     llvm::Value* x_signed_zero = ir_builder->CreateAnd(
    127         x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
    128     llvm::Value* x_signed_inf = ir_builder->CreateOr(
    129         x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
    130 
    131     // Force to zero or infinity if overflow or underflow.  (Note that this
    132     // truncates all denormal values to zero, rather than rounding them.)
    133     x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    134     x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int);
    135   }
    136 
    137   // Cast the result back to a floating-point type.
    138   llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type);
    139 
    140   // Correct result for NaN inputs.
    141   //
    142   // The exponent handling will "normalize" NaN values to infinities, which is
    143   // undesirable (except in the case with no mantissa bits, in which case it
    144   // is mandatory).  This logic also handles cases where mantissa-rounding
    145   // causes a NaN's mantissa to overflow into the exponent bits, which would
    146   // otherwise create an erroneous zero value.
    147   //
    148   // If the fast-math flags are set to assume no NaNs, the comparison is likely
    149   // to be optimized away, so there's no point in even emitting it.
    150   if (!ir_builder->getFastMathFlags().noNaNs()) {
    151     llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x);
    152 
    153     if (mantissa_bits > 0) {
    154       result = ir_builder->CreateSelect(x_is_nan, x, result);
    155     } else {
    156       result = ir_builder->CreateSelect(
    157           x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
    158     }
    159   }
    160   return result;
    161 }
    162 
    163 llvm::Value* EmitF32ToBF16(llvm::Value* f32_value,
    164                            llvm::IRBuilder<>* ir_builder) {
    165   auto reduced_precision = EmitReducePrecisionFloat(
    166       f32_value,
    167       /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
    168       /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder);
    169   auto as_int32 =
    170       ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty());
    171   auto shifted = ir_builder->CreateLShr(as_int32, 16);
    172   auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty());
    173   return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty());
    174 }
    175 
    176 llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value,
    177                            llvm::IRBuilder<>* ir_builder) {
    178   auto as_int16 =
    179       ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty());
    180   auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty());
    181   auto shifted = ir_builder->CreateShl(as_int32, 16);
    182   return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy());
    183 }
    184 
    185 llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
    186                                     PrimitiveType from_type,
    187                                     PrimitiveType to_type, llvm::Module* module,
    188                                     llvm::IRBuilder<>* ir_builder) {
    189   if (primitive_util::IsSignedIntegralType(from_type)) {
    190     return ir_builder->CreateSIToFP(
    191         integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
    192   } else {
    193     CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
    194           from_type == PRED);
    195     return ir_builder->CreateUIToFP(
    196         integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
    197   }
    198 }
    199 
    200 }  // namespace
    201 
    202 StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    203     const HloInstruction* op, llvm::Value* operand_value) const {
    204   if (op->opcode() == HloOpcode::kCopy) {
    205     return operand_value;
    206   } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
    207              op->operand(0)->shape().element_type() == PRED) {
    208     return EmitIntegerUnaryOp(op, operand_value);
    209   } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    210     return EmitComplexUnaryOp(op, operand_value);
    211   } else {
    212     return EmitFloatUnaryOp(op, operand_value);
    213   }
    214 }
    215 
    216 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    217     const HloInstruction* op, llvm::Value* operand_value) const {
    218   switch (op->opcode()) {
    219     case HloOpcode::kConvert: {
    220       PrimitiveType from_type = op->operand(0)->shape().element_type();
    221       PrimitiveType to_type = op->shape().element_type();
    222       CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED);
    223       if (from_type == to_type) {
    224         return operand_value;
    225       }
    226       if (primitive_util::IsIntegralType(to_type)) {
    227         return ir_builder_->CreateIntCast(
    228             operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
    229             primitive_util::IsSignedIntegralType(to_type));
    230       }
    231       if (primitive_util::IsFloatingPointType(to_type)) {
    232         if (to_type == BF16) {
    233           return EmitF32ToBF16(
    234               EmitIntegralToFloating(operand_value, from_type, F32, module_,
    235                                      ir_builder_),
    236               ir_builder_);
    237         }
    238         return EmitIntegralToFloating(operand_value, from_type, to_type,
    239                                       module_, ir_builder_);
    240       }
    241       if (primitive_util::IsComplexType(to_type)) {
    242         auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
    243             primitive_util::ComplexComponentType(to_type), module_);
    244         if (primitive_util::IsSignedIntegralType(from_type)) {
    245           return EmitComposeComplex(
    246               op,
    247               ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
    248               nullptr);
    249         }
    250         if (primitive_util::IsUnsignedIntegralType(from_type) ||
    251             from_type == PRED) {
    252           return EmitComposeComplex(
    253               op,
    254               ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
    255               nullptr);
    256         }
    257       }
    258       return Unimplemented("conversion from primitive type %s to %s",
    259                            PrimitiveType_Name(from_type).c_str(),
    260                            PrimitiveType_Name(to_type).c_str());
    261     }
    262     case HloOpcode::kBitcastConvert: {
    263       PrimitiveType from_type = op->operand(0)->shape().element_type();
    264       PrimitiveType to_type = op->shape().element_type();
    265       CHECK(primitive_util::IsIntegralType(from_type));
    266       if (from_type == to_type) {
    267         return operand_value;
    268       }
    269       if (primitive_util::BitWidth(from_type) ==
    270           primitive_util::BitWidth(to_type)) {
    271         return ir_builder_->CreateBitCast(
    272             operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
    273       }
    274       return InvalidArgument(
    275           "bitcast conversion from primitive type %s to %s with unequal "
    276           "bit-widths (%u versus %u) ",
    277           PrimitiveType_Name(from_type).c_str(),
    278           PrimitiveType_Name(to_type).c_str(),
    279           primitive_util::BitWidth(from_type),
    280           primitive_util::BitWidth(to_type));
    281     }
    282     case HloOpcode::kAbs: {
    283       bool is_signed =
    284           primitive_util::IsSignedIntegralType(op->shape().element_type());
    285       if (is_signed) {
    286         auto type =
    287             llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
    288         auto zero = llvm::ConstantInt::get(type, 0);
    289         auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
    290         return ir_builder_->CreateSelect(cmp, operand_value,
    291                                          ir_builder_->CreateNeg(operand_value));
    292       } else {
    293         return operand_value;
    294       }
    295     }
    296     case HloOpcode::kSign: {
    297       bool is_signed =
    298           primitive_util::IsSignedIntegralType(op->shape().element_type());
    299       auto type =
    300           llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
    301       auto zero = llvm::ConstantInt::get(type, 0);
    302       auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
    303       if (is_signed) {
    304         auto ashr = ir_builder_->CreateAShr(operand_value,
    305                                             type->getIntegerBitWidth() - 1);
    306         return ir_builder_->CreateSelect(cmp, zero,
    307                                          ir_builder_->CreateOr(ashr, 1));
    308       } else {
    309         return ir_builder_->CreateSelect(cmp, zero,
    310                                          llvm::ConstantInt::get(type, 1));
    311       }
    312     }
    313     case HloOpcode::kNegate:
    314       return ir_builder_->CreateNeg(operand_value);
    315     case HloOpcode::kNot: {
    316       auto type = op->shape().element_type();
    317       if (type == PRED) {
    318         // It is not sufficient to just call CreateNot() here because a PRED
    319         // is represented as an i8 and the truth value is stored only in the
    320         // bottom bit.
    321         return ir_builder_->CreateZExt(
    322             ir_builder_->CreateNot(ir_builder_->CreateTrunc(
    323                 operand_value, ir_builder_->getInt1Ty())),
    324             llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    325       } else if (primitive_util::IsIntegralType(type)) {
    326         return ir_builder_->CreateNot(operand_value);
    327       }
    328       return Unimplemented("unary op Not is not defined for type '%d'", type);
    329     }
    330     default:
    331       return Unimplemented("unary integer op '%s'",
    332                            HloOpcodeString(op->opcode()).c_str());
    333   }
    334 }
    335 
    336 StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    337     const HloInstruction* op, llvm::Value* operand_value) const {
    338   switch (op->opcode()) {
    339     case HloOpcode::kConvert: {
    340       PrimitiveType from_type = op->operand(0)->shape().element_type();
    341       PrimitiveType to_type = op->shape().element_type();
    342       CHECK(primitive_util::IsFloatingPointType(from_type));
    343       if (from_type == to_type) {
    344         return operand_value;
    345       }
    346       if (primitive_util::IsComplexType(to_type)) {
    347         PrimitiveType to_component_type =
    348             primitive_util::ComplexComponentType(to_type);
    349         if (from_type == to_component_type) {
    350           return EmitComposeComplex(op, operand_value, nullptr);
    351         }
    352         return EmitComposeComplex(
    353             op,
    354             ir_builder_->CreateFPCast(
    355                 operand_value,
    356                 llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
    357             nullptr);
    358       }
    359       if (from_type == BF16) {
    360         TF_RET_CHECK(to_type != BF16);
    361         operand_value = EmitBF16ToF32(operand_value, ir_builder_);
    362         from_type = F32;
    363         if (from_type == to_type) {
    364           return operand_value;
    365         }
    366       }
    367       if (from_type == F32 && to_type == BF16) {
    368         return EmitF32ToBF16(operand_value, ir_builder_);
    369       }
    370       if (primitive_util::IsFloatingPointType(to_type)) {
    371         return ir_builder_->CreateFPCast(
    372             operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
    373       }
    374       if (primitive_util::IsSignedIntegralType(to_type)) {
    375         return ir_builder_->CreateFPToSI(
    376             operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
    377       }
    378       if (primitive_util::IsUnsignedIntegralType(to_type)) {
    379         return ir_builder_->CreateFPToUI(
    380             operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
    381       }
    382       return Unimplemented("unhandled conversion operation: %s => %s",
    383                            PrimitiveType_Name(from_type).c_str(),
    384                            PrimitiveType_Name(to_type).c_str());
    385     }
    386     case HloOpcode::kBitcastConvert: {
    387       PrimitiveType from_type = op->operand(0)->shape().element_type();
    388       PrimitiveType to_type = op->shape().element_type();
    389       CHECK(primitive_util::IsFloatingPointType(from_type));
    390       if (from_type == to_type) {
    391         return operand_value;
    392       }
    393       if (primitive_util::BitWidth(from_type) ==
    394           primitive_util::BitWidth(to_type)) {
    395         return ir_builder_->CreateBitCast(
    396             operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
    397       }
    398       return InvalidArgument(
    399           "bitcast conversion from primitive type %s to %s with unequal "
    400           "bit-widths (%u versus %u) ",
    401           PrimitiveType_Name(from_type).c_str(),
    402           PrimitiveType_Name(to_type).c_str(),
    403           primitive_util::BitWidth(from_type),
    404           primitive_util::BitWidth(to_type));
    405     }
    406     case HloOpcode::kExp:
    407       return EmitExp(op->shape().element_type(), operand_value);
    408     case HloOpcode::kLog:
    409       return EmitLog(op->shape().element_type(), operand_value);
    410     case HloOpcode::kCos:
    411       return EmitCos(op->shape().element_type(), operand_value);
    412     case HloOpcode::kSin:
    413       return EmitSin(op->shape().element_type(), operand_value);
    414     case HloOpcode::kFloor:
    415       return llvm_ir::EmitCallToIntrinsic(
    416           llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
    417           ir_builder_);
    418     case HloOpcode::kCeil:
    419       return llvm_ir::EmitCallToIntrinsic(
    420           llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()},
    421           ir_builder_);
    422     case HloOpcode::kAbs:
    423       return llvm_ir::EmitCallToIntrinsic(
    424           llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()},
    425           ir_builder_);
    426     case HloOpcode::kRoundNearestAfz:
    427       return llvm_ir::EmitCallToIntrinsic(
    428           llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
    429           ir_builder_);
    430     case HloOpcode::kSign: {
    431       // TODO(b/32151903): Ensure consistent sign behavior for -0.0.
    432       auto type = operand_value->getType();
    433       auto zero = llvm::ConstantFP::get(type, 0.0);
    434       auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
    435       auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero);
    436       return ir_builder_->CreateSelect(
    437           oeq, zero,
    438           ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
    439                                     llvm::ConstantFP::get(type, 1.0)));
    440     }
    441     case HloOpcode::kIsFinite: {
    442       // (x == x) && abs(x) != inf
    443       auto type = operand_value->getType();
    444       auto equal_self =
    445           ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
    446       auto abs_value = llvm_ir::EmitCallToIntrinsic(
    447           llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
    448       auto infinity = llvm::ConstantFP::getInfinity(type);
    449       auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
    450       auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
    451       return ir_builder_->CreateZExt(
    452           result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    453     }
    454     case HloOpcode::kNegate:
    455       return ir_builder_->CreateFNeg(operand_value);
    456     default:
    457       return Unimplemented("unary floating-point op '%s'",
    458                            HloOpcodeString(op->opcode()).c_str());
    459   }
    460 }
    461 
    462 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    463     const HloInstruction* op, llvm::Value* operand_value) const {
    464   PrimitiveType input_type = op->operand(0)->shape().element_type();
    465   PrimitiveType component_type =
    466       primitive_util::IsComplexType(input_type)
    467           ? primitive_util::ComplexComponentType(input_type)
    468           : input_type;
    469   switch (op->opcode()) {
    470     case HloOpcode::kLog: {
    471       // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
    472       auto a = EmitExtractReal(operand_value);
    473       auto b = EmitExtractImag(operand_value);
    474       llvm::Type* llvm_ty = a->getType();
    475       auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
    476                                             ir_builder_->CreateFMul(b, b));
    477       TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
    478       TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
    479       auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
    480       return EmitComposeComplex(
    481           op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
    482     }
    483     case HloOpcode::kConvert: {
    484       PrimitiveType from_type = op->operand(0)->shape().element_type();
    485       TF_RET_CHECK(primitive_util::IsComplexType(from_type));
    486       PrimitiveType to_type = op->shape().element_type();
    487       TF_RET_CHECK(primitive_util::IsComplexType(to_type));
    488       if (from_type == to_type) {
    489         return operand_value;
    490       }
    491       PrimitiveType to_component_type =
    492           primitive_util::ComplexComponentType(to_type);
    493       auto to_ir_component_type =
    494           llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
    495       return EmitComposeComplex(
    496           op,
    497           ir_builder_->CreateFPCast(EmitExtractReal(operand_value),
    498                                     to_ir_component_type),
    499           ir_builder_->CreateFPCast(EmitExtractImag(operand_value),
    500                                     to_ir_component_type));
    501     }
    502     case HloOpcode::kExp: {
    503       // e^(a+bi) = e^a*(cos(b)+sin(b)i)
    504       TF_ASSIGN_OR_RETURN(
    505           auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
    506       TF_ASSIGN_OR_RETURN(
    507           auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
    508       TF_ASSIGN_OR_RETURN(
    509           auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
    510       return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
    511                                 ir_builder_->CreateFMul(exp_a, sin_b));
    512     }
    513     case HloOpcode::kCos: {
    514       // cos(z) = .5(e^(iz) + e^(-iz))
    515       // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
    516       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
    517       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
    518       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
    519       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
    520       //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
    521       auto a = EmitExtractReal(operand_value);
    522       auto b = EmitExtractImag(operand_value);
    523       auto type = a->getType();
    524       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
    525       auto half_exp_b =
    526           ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
    527       auto half_exp_neg_b =
    528           ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
    529       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
    530       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
    531       return EmitComposeComplex(
    532           op,
    533           ir_builder_->CreateFMul(
    534               cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
    535           ir_builder_->CreateFMul(
    536               sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
    537     }
    538     case HloOpcode::kSin: {
    539       // sin(z) = .5i(e^(-iz) - e^(iz))
    540       // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
    541       //           = .5i(e^(b-ai) - e^(-b+ai))
    542       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
    543       // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
    544       //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
    545       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
    546       //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
    547       //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
    548       auto a = EmitExtractReal(operand_value);
    549       auto b = EmitExtractImag(operand_value);
    550       auto type = a->getType();
    551       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
    552       auto half_exp_b =
    553           ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
    554       auto half_exp_neg_b =
    555           ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
    556       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
    557       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
    558       return EmitComposeComplex(
    559           op,
    560           ir_builder_->CreateFMul(
    561               sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
    562           ir_builder_->CreateFMul(
    563               cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
    564     }
    565     case HloOpcode::kTanh: {
    566       /*
    567       tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
    568       e^(a+bi) = e^a*(cos(b)+sin(b)i)
    569       so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
    570               (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
    571       cos(b)=cos(-b), sin(-b)=-sin(b)
    572       so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
    573               (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
    574              =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
    575               (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
    576              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
    577               (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
    578       This is a complex division, so we can multiply by denom_conj/denom_conj
    579              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
    580               (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
    581               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
    582              =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
    583                i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
    584               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
    585       */
    586       auto a = EmitExtractReal(operand_value);
    587       auto b = EmitExtractImag(operand_value);
    588       TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
    589       TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
    590       TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
    591       auto exp_neg_a = ir_builder_->CreateFDiv(
    592           llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
    593       auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
    594           ir_builder_->CreateFMul(exp_a, exp_a),
    595           ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
    596       auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
    597       auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
    598       auto real_num = ir_builder_->CreateFAdd(
    599           ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
    600           ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
    601       auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
    602       auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
    603       auto exp_a_plus_exp_neg_a_sq =
    604           ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
    605       auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
    606       auto exp_a_minus_exp_neg_a_sq =
    607           ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
    608       auto imag_num = ir_builder_->CreateFMul(
    609           cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
    610                                                exp_a_minus_exp_neg_a_sq));
    611       auto denom = ir_builder_->CreateFAdd(
    612           ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
    613           ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
    614       return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
    615                                 ir_builder_->CreateFDiv(imag_num, denom));
    616     }
    617     case HloOpcode::kAbs: {
    618       auto sum_sq = ir_builder_->CreateFAdd(
    619           ir_builder_->CreateFMul(EmitExtractReal(operand_value),
    620                                   EmitExtractReal(operand_value)),
    621           ir_builder_->CreateFMul(EmitExtractImag(operand_value),
    622                                   EmitExtractImag(operand_value)));
    623       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
    624                                           {sum_sq->getType()}, ir_builder_);
    625     }
    626     case HloOpcode::kSign: {  // Sign(c) = c / |c|
    627       auto sum_sq = ir_builder_->CreateFAdd(
    628           ir_builder_->CreateFMul(EmitExtractReal(operand_value),
    629                                   EmitExtractReal(operand_value)),
    630           ir_builder_->CreateFMul(EmitExtractImag(operand_value),
    631                                   EmitExtractImag(operand_value)));
    632       auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
    633           llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
    634       auto type = cplx_abs->getType();
    635       auto zero = llvm::ConstantFP::get(type, 0.0);
    636       auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
    637       return ir_builder_->CreateSelect(
    638           oeq, EmitComposeComplex(op, zero, zero),
    639           EmitComposeComplex(
    640               op,
    641               ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
    642               ir_builder_->CreateFDiv(EmitExtractImag(operand_value),
    643                                       cplx_abs)));
    644     }
    645     case HloOpcode::kNegate:
    646       return EmitComposeComplex(
    647           op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)),
    648           ir_builder_->CreateFNeg(EmitExtractImag(operand_value)));
    649     case HloOpcode::kReal:
    650       return EmitExtractReal(operand_value);
    651     case HloOpcode::kImag:
    652       return EmitExtractImag(operand_value);
    653     default:
    654       return Unimplemented("unary complex op '%s'",
    655                            HloOpcodeString(op->opcode()).c_str());
    656   }
    657 }
    658 
    659 StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    660     const HloInstruction* op, llvm::Value* lhs_value,
    661     llvm::Value* rhs_value) const {
    662   PrimitiveType operand_type = op->operand(0)->shape().element_type();
    663   if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
    664       operand_type == PRED) {
    665     return EmitIntegerBinaryOp(
    666         op, lhs_value, rhs_value,
    667         primitive_util::IsSignedIntegralType(operand_type));
    668   } else if (primitive_util::IsComplexType(operand_type)) {
    669     return EmitComplexBinaryOp(op, lhs_value, rhs_value);
    670   } else {
    671     return EmitFloatBinaryOp(op, lhs_value, rhs_value);
    672   }
    673 }
    674 
    675 StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    676     const HloInstruction* op, llvm::Value* lhs_value,
    677     llvm::Value* rhs_value) const {
    678   switch (op->opcode()) {
    679     case HloOpcode::kComplex:
    680       return EmitComposeComplex(op, lhs_value, rhs_value);
    681     case HloOpcode::kAdd:
    682       return ir_builder_->CreateFAdd(lhs_value, rhs_value);
    683     case HloOpcode::kSubtract:
    684       return ir_builder_->CreateFSub(lhs_value, rhs_value);
    685     case HloOpcode::kMultiply:
    686       return ir_builder_->CreateFMul(lhs_value, rhs_value);
    687     case HloOpcode::kDivide:
    688       return ir_builder_->CreateFDiv(lhs_value, rhs_value);
    689     case HloOpcode::kRemainder:
    690       return ir_builder_->CreateFRem(lhs_value, rhs_value);
    691     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    692     // comparisons always return false when one of the operands is NaN, whereas
    693     // unordered comparisons return true.
    694     //
    695     // We use ordered comparisons for everything except kNe, where we use an
    696     // unordered comparison.  This makes x != y equivalent to !(x == y), and
    697     // matches C++'s semantics.
    698     case HloOpcode::kEq:
    699       return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
    700                                      rhs_value, ir_builder_);
    701     case HloOpcode::kNe:
    702       return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
    703                                      rhs_value, ir_builder_);
    704     case HloOpcode::kLt:
    705       return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
    706                                      rhs_value, ir_builder_);
    707     case HloOpcode::kGt:
    708       return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
    709                                      rhs_value, ir_builder_);
    710     case HloOpcode::kLe:
    711       return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
    712                                      rhs_value, ir_builder_);
    713     case HloOpcode::kGe:
    714       return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
    715                                      rhs_value, ir_builder_);
    716 
    717     case HloOpcode::kMaximum:
    718       return EmitFloatMax(lhs_value, rhs_value);
    719     case HloOpcode::kMinimum:
    720       return EmitFloatMin(lhs_value, rhs_value);
    721     case HloOpcode::kPower:
    722       return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
    723     case HloOpcode::kAtan2:
    724       return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
    725     default:
    726       return Unimplemented("binary floating point op '%s'",
    727                            HloOpcodeString(op->opcode()).c_str());
    728   }
    729 }
    730 
    731 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    732     const HloInstruction* op, llvm::Value* lhs_value,
    733     llvm::Value* rhs_value) const {
    734   switch (op->opcode()) {
    735     case HloOpcode::kAdd:
    736       return EmitComposeComplex(
    737           op,
    738           ir_builder_->CreateFAdd(EmitExtractReal(lhs_value),
    739                                   EmitExtractReal(rhs_value)),
    740           ir_builder_->CreateFAdd(EmitExtractImag(lhs_value),
    741                                   EmitExtractImag(rhs_value)));
    742     case HloOpcode::kSubtract:
    743       return EmitComposeComplex(
    744           op,
    745           ir_builder_->CreateFSub(EmitExtractReal(lhs_value),
    746                                   EmitExtractReal(rhs_value)),
    747           ir_builder_->CreateFSub(EmitExtractImag(lhs_value),
    748                                   EmitExtractImag(rhs_value)));
    749     case HloOpcode::kMultiply:
    750       return EmitComposeComplex(
    751           op,
    752           ir_builder_->CreateFSub(
    753               ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
    754                                       EmitExtractReal(rhs_value)),
    755               ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
    756                                       EmitExtractImag(rhs_value))),
    757           ir_builder_->CreateFAdd(
    758               ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
    759                                       EmitExtractImag(rhs_value)),
    760               ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
    761                                       EmitExtractReal(rhs_value))));
    762     case HloOpcode::kDivide: {
    763       // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
    764       // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
    765       auto rhs_sum_sq = ir_builder_->CreateFAdd(
    766           ir_builder_->CreateFMul(EmitExtractReal(rhs_value),
    767                                   EmitExtractReal(rhs_value)),
    768           ir_builder_->CreateFMul(EmitExtractImag(rhs_value),
    769                                   EmitExtractImag(rhs_value)));
    770       auto type = rhs_sum_sq->getType();
    771       auto zero = llvm::ConstantFP::get(type, 0.0);
    772       auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
    773       auto real_inf_or_nan =
    774           ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero);
    775       auto imag_inf_or_nan =
    776           ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero);
    777       return ir_builder_->CreateSelect(
    778           oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
    779           EmitComposeComplex(
    780               op,
    781               ir_builder_->CreateFDiv(
    782                   ir_builder_->CreateFAdd(
    783                       ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
    784                                               EmitExtractReal(rhs_value)),
    785                       ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
    786                                               EmitExtractImag(rhs_value))),
    787                   rhs_sum_sq),
    788               ir_builder_->CreateFDiv(
    789                   ir_builder_->CreateFSub(
    790                       ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
    791                                               EmitExtractReal(rhs_value)),
    792                       ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
    793                                               EmitExtractImag(rhs_value))),
    794                   rhs_sum_sq)));
    795     }
    796     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    797     // comparisons always return false when one of the operands is NaN, whereas
    798     // unordered comparisons return true.
    799     //
    800     // We use ordered comparisons for everything except kNe, where we use an
    801     // unordered comparison.  This makes x != y equivalent to !(x == y), and
    802     // matches C++'s semantics.
    803     case HloOpcode::kEq:
    804       return ir_builder_->CreateAnd(
    805           llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
    806                                   EmitExtractReal(lhs_value),
    807                                   EmitExtractReal(rhs_value), ir_builder_),
    808           llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
    809                                   EmitExtractImag(lhs_value),
    810                                   EmitExtractImag(rhs_value), ir_builder_));
    811     case HloOpcode::kNe:
    812       return ir_builder_->CreateOr(
    813           llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
    814                                   EmitExtractReal(lhs_value),
    815                                   EmitExtractReal(rhs_value), ir_builder_),
    816           llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
    817                                   EmitExtractImag(lhs_value),
    818                                   EmitExtractImag(rhs_value), ir_builder_));
    819 
    820     case HloOpcode::kPower: {
    821       // (a+bi)^(c+di) =
    822       //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
    823       //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
    824       PrimitiveType component_type =
    825           primitive_util::ComplexComponentType(op->shape().element_type());
    826       auto a = EmitExtractReal(lhs_value);
    827       auto b = EmitExtractImag(lhs_value);
    828       auto c = EmitExtractReal(rhs_value);
    829       auto d = EmitExtractImag(rhs_value);
    830       auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
    831                                              ir_builder_->CreateFMul(b, b));
    832       auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
    833       auto half_c = ir_builder_->CreateFMul(one_half, c);
    834 
    835       TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
    836                           EmitPow(component_type, aa_p_bb, half_c));
    837       auto neg_d = ir_builder_->CreateFNeg(d);
    838       TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
    839       auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs);
    840       TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
    841                           EmitExp(component_type, neg_d_arg_lhs));
    842       auto coeff =
    843           ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
    844       TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
    845       auto half_d = ir_builder_->CreateFMul(one_half, d);
    846       auto q =
    847           ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs),
    848                                   ir_builder_->CreateFMul(half_d, ln_aa_p_bb));
    849       TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
    850       TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
    851       return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q),
    852                                 ir_builder_->CreateFMul(coeff, sin_q));
    853     }
    854     default:
    855       return Unimplemented("binary complex op '%s'",
    856                            HloOpcodeString(op->opcode()).c_str());
    857   }
    858 }
    859 
    860 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
    861                                               llvm::Value* rhs_value) const {
          // Thin wrapper: forwards both operands together with this emitter's
          // ir_builder_ to llvm_ir::EmitFloatMax and returns its result unchanged.
    862   return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
    863 }
    864 
    865 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
    866                                               llvm::Value* rhs_value) const {
          // Thin wrapper: forwards both operands together with this emitter's
          // ir_builder_ to llvm_ir::EmitFloatMin and returns its result unchanged.
    867   return llvm_ir::EmitFloatMin(lhs_value, rhs_value, ir_builder_);
    868 }
    869 
    870 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
    871                                                       llvm::Value* x) const {
    872   if (prim_type != F32) {
    873     // TODO(b/34339814): Implement inverse erf for F64.
    874     return Unimplemented(
    875         "Inverse erf is only implemented for element "
    876         "type F32.");
    877   }
    878   auto getFloat = [&](const float f) {
    879     return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
    880   };
    881   auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
    882                           llvm::Value* w) {
    883     llvm::Value* p = getFloat(coefficients.front());
    884     coefficients.pop_front();
    885     for (float coefficient : coefficients) {
    886       p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w),
    887                                   getFloat(coefficient));
    888     }
    889     return p;
    890   };
    891 
    892   // Approximation for inverse error function from
    893   //   Giles, M., "Approximating the erfinv function".
    894   // The approximation has the form:
    895   //   w = log((1-x)*(1+x))
    896   //   if ( w < 5 ) {
    897   //     w = w - 2.5
    898   //     p = sum_{i=1}^n lq[i]*w^i
    899   //   } else {
    900   //     w = sqrt(w) - 3
    901   //     p = sum_{i=1}^n gq[i]*w^i
    902   //   }
    903   //   return p*x
    904   llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
    905       module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()});
    906 
    907   llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall(
    908       logf_fn,
    909       {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x),
    910                                ir_builder_->CreateFAdd(getFloat(1.0f), x))}));
    911 
    912   llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(
    913       ir_builder_->getFloatTy(), "p.addr", ir_builder_);
    914 
    915   llvm_ir::LlvmIfData if_data =
    916       llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)),
    917                               "w_less_than_five", ir_builder_);
    918   // Handle true BB.
    919   SetToFirstInsertPoint(if_data.true_block, ir_builder_);
    920   {
    921     llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f));
    922     tensorflow::gtl::ArraySlice<float> lq{
    923         2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
    924         -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
    925         -0.00417768164f,  0.246640727f,    1.50140941f};
    926     llvm::Value* p = multiply_add(lq, lw);
    927     ir_builder_->CreateStore(p, p_addr);
    928   }
    929 
    930   // Handle false BB.
    931   SetToFirstInsertPoint(if_data.false_block, ir_builder_);
    932   {
    933     llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
    934         module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()});
    935 
    936     llvm::Value* gw = ir_builder_->CreateFSub(
    937         ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
    938     tensorflow::gtl::ArraySlice<float> gq{
    939         -0.000200214257f, 0.000100950558f, 0.00134934322f,
    940         -0.00367342844f,  0.00573950773f,  -0.0076224613f,
    941         0.00943887047f,   1.00167406f,     2.83297682f};
    942     llvm::Value* p = multiply_add(gq, gw);
    943     ir_builder_->CreateStore(p, p_addr);
    944   }
    945 
    946   SetToFirstInsertPoint(if_data.after_block, ir_builder_);
    947   llvm::Value* p = ir_builder_->CreateLoad(p_addr);
    948   return ir_builder_->CreateFMul(p, x);
    949 }
    950 
    951 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
    952     PrimitiveType prim_type, llvm::Value* value) const {
    953   // Compute erfcinv(value) by calculating erfinv(1.0 - value).
    954   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
    955   auto one = llvm::ConstantFP::get(type, 1.0);
    956   return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
    957 }
    958 
    959 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
    960                                                    llvm::Value* value) const {
    961   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
    962                                       {value->getType()}, ir_builder_);
    963 }
    964 
    965 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
    966                                                    llvm::Value* value) const {
    967   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
    968                                       {value->getType()}, ir_builder_);
    969 }
    970 
    971 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
    972                                                    llvm::Value* value) const {
    973   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
    974                                       {value->getType()}, ir_builder_);
    975 }
    976 
    977 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
    978                                                    llvm::Value* value) const {
    979   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
    980                                       {value->getType()}, ir_builder_);
    981 }
    982 
    983 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
    984                                                    llvm::Value* lhs,
    985                                                    llvm::Value* rhs) const {
    986   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
    987                                       {lhs->getType()}, ir_builder_);
    988 }
    989 
    990 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
    991                                                      llvm::Value* lhs,
    992                                                      llvm::Value* rhs) const {
          // This implementation emits no IR: all three arguments are ignored and an
          // Unimplemented status is returned unconditionally.
          // NOTE(review): presumably backend-specific emitters supply a real atan2
          // by overriding this method -- confirm against the class declaration.
    993   return Unimplemented("atan2");
    994 }
    995 
    996 StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    997     const HloInstruction* hlo, llvm::Value* x) const {
    998   if (hlo->operand(0)->shape().element_type() != F32) {
    999     return Unimplemented("reduce-precision only implemented for F32");
   1000   }
   1001   return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
   1002                                   /*mantissa_bits=*/hlo->mantissa_bits(),
   1003                                   ir_builder_);
   1004 }
   1005 
   1006 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
   1007     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
   1008     bool is_signed) const {
   1009   switch (op->opcode()) {
   1010     // TODO(jingyue): add the "nsw" attribute for signed types.
   1011     case HloOpcode::kAdd:
   1012       return ir_builder_->CreateAdd(lhs_value, rhs_value);
   1013     case HloOpcode::kSubtract:
   1014       return ir_builder_->CreateSub(lhs_value, rhs_value);
   1015     case HloOpcode::kMultiply:
   1016       return ir_builder_->CreateMul(lhs_value, rhs_value);
   1017     case HloOpcode::kDivide:
   1018       return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value)
   1019                        : ir_builder_->CreateUDiv(lhs_value, rhs_value);
   1020     case HloOpcode::kRemainder:
   1021       return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value)
   1022                        : ir_builder_->CreateURem(lhs_value, rhs_value);
   1023     case HloOpcode::kEq:
   1024       return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
   1025                                      rhs_value, ir_builder_);
   1026     case HloOpcode::kNe:
   1027       return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
   1028                                      rhs_value, ir_builder_);
   1029     case HloOpcode::kLt:
   1030       return llvm_ir::EmitComparison(
   1031           is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
   1032           lhs_value, rhs_value, ir_builder_);
   1033     case HloOpcode::kGt:
   1034       return llvm_ir::EmitComparison(
   1035           is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
   1036           lhs_value, rhs_value, ir_builder_);
   1037     case HloOpcode::kLe:
   1038       return llvm_ir::EmitComparison(
   1039           is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
   1040           lhs_value, rhs_value, ir_builder_);
   1041     case HloOpcode::kGe:
   1042       return llvm_ir::EmitComparison(
   1043           is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
   1044           lhs_value, rhs_value, ir_builder_);
   1045     case HloOpcode::kMinimum:
   1046       return EmitIntegralMin(lhs_value, rhs_value, is_signed);
   1047     case HloOpcode::kMaximum:
   1048       return EmitIntegralMax(lhs_value, rhs_value, is_signed);
   1049     case HloOpcode::kAnd:
   1050       return ir_builder_->CreateAnd(lhs_value, rhs_value);
   1051     case HloOpcode::kOr:
   1052       return ir_builder_->CreateOr(lhs_value, rhs_value);
   1053     case HloOpcode::kShiftLeft:
   1054       return ir_builder_->CreateShl(lhs_value, rhs_value);
   1055     case HloOpcode::kShiftRightArithmetic:
   1056       return ir_builder_->CreateAShr(lhs_value, rhs_value);
   1057     case HloOpcode::kShiftRightLogical:
   1058       return ir_builder_->CreateLShr(lhs_value, rhs_value);
   1059     default:
   1060       return Unimplemented("binary integer op '%s'",
   1061                            HloOpcodeString(op->opcode()).c_str());
   1062   }
   1063 }
   1064 
   1065 llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
   1066                                                  llvm::Value* rhs_value,
   1067                                                  bool is_signed) const {
   1068   return ir_builder_->CreateSelect(
   1069       ir_builder_->CreateICmp(
   1070           is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
   1071           lhs_value, rhs_value),
   1072       lhs_value, rhs_value);
   1073 }
   1074 
   1075 llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
   1076                                                  llvm::Value* rhs_value,
   1077                                                  bool is_signed) const {
   1078   return ir_builder_->CreateSelect(
   1079       ir_builder_->CreateICmp(
   1080           is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
   1081           lhs_value, rhs_value),
   1082       lhs_value, rhs_value);
   1083 }
   1084 
   1085 llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
   1086     const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
   1087     int64 operand_no) const {
   1088   CHECK(hlo.IsElementwise())
   1089       << "HLO " << hlo.ToString() << " is not elementwise.";
   1090 
   1091   const Shape& operand_shape = hlo.operand(operand_no)->shape();
   1092   // If the operand is scalar, the source index is always {}.
   1093   if (ShapeUtil::IsScalar(operand_shape)) {
   1094     return llvm_ir::IrArray::Index();
   1095   }
   1096 
   1097   // If no implicit broadcast is needed for this operand, returns the target
   1098   // index as the source index.
   1099   if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
   1100     return target_index;
   1101   }
   1102 
   1103   // If implicit broadcast is needed, the source dimensions that are broadcast
   1104   // have index 0.
   1105   CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
   1106   llvm_ir::IrArray::Index source_index;
   1107   for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
   1108     if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
   1109       source_index.push_back(target_index[i]);
   1110     } else {
   1111       CHECK_EQ(1, operand_shape.dimensions(i));
   1112       source_index.push_back(ir_builder_->getInt64(0));
   1113     }
   1114   }
   1115   return source_index;
   1116 }
   1117 
   1118 llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
   1119     const HloInstruction* hlo,
   1120     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
   1121     const {
   1122   PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
   1123   llvm::Type* param_ir_type =
   1124       llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_);
   1125 
   1126   // Same values as PCG library
   1127   // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
   1128   llvm::Value* multiplier = ir_builder_->getInt(
   1129       llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
   1130   llvm::Value* increment = ir_builder_->getInt(
   1131       llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
   1132 
   1133   auto random_value = [hlo]() {
   1134     const HloModule* module =
   1135         hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent()
   1136                        : hlo->parent()->parent();
   1137     return module->RandomNew64();
   1138   };
   1139 
   1140   // Seed each RNG emitter with a new 64-bit seed from the HloModule. If the
   1141   // compilation order is deterministic (i.e., RandomNew64 invocation order is
   1142   // deterministic), then the order of RNG is deterministic for a given seed and
   1143   // hence tests will be deterministic.
   1144   // If the user provides a global seed, then only 64 bits of the 128-bit
   1145   // state are seeded from the host's random number generator; the other
   1146   // 64 bits come from the user-specified global seed.
   1147   // Create a GlobalVariable to maintain state between invocations. There is a
   1148   // bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit
   1149   // values.
   1150   llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable(
   1151       /*M=*/*module_,
   1152       /*Ty=*/ir_builder_->getInt64Ty(),
   1153       /*isConstant=*/false,
   1154       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
   1155       /*Initializer=*/ir_builder_->getInt64(random_value()),
   1156       /*Name=*/"state_ptr0");
   1157   uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
   1158                                                      : random_value();
   1159   llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
   1160       /*M=*/*module_,
   1161       /*Ty=*/ir_builder_->getInt64Ty(),
   1162       /*isConstant=*/false,
   1163       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
   1164       /*Initializer=*/ir_builder_->getInt64(graph_seed),
   1165       /*Name=*/"state_ptr1");
   1166 
   1167   // We want each thread to use its own stream, so we modify the increment per
   1168   // thread. We want the increment to remain odd, so we shift the thread id left
   1169   // 1 and add it to the increment.
   1170   increment = ir_builder_->CreateAdd(increment,
   1171                                      ir_builder_->CreateShl(EmitThreadId(), 1));
   1172 
   1173   // PCG-XSL-RR algorithm
   1174   // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
   1175   //   state = multiplier * state + increment
   1176   //   return uint64_t(state ^ (state >> 64)) >>> (state >> 122)
   1177   // where ">>>" is bitwise rotation
   1178   auto get_next_i64 = [=]() {
   1179     llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc(
   1180         ir_builder_->CreateLoad(state_ptr0, "state0"),
   1181         ir_builder_->getInt128Ty());
   1182     llvm::Value* state1 = ir_builder_->CreateShl(
   1183         ir_builder_->CreateZExtOrTrunc(
   1184             ir_builder_->CreateLoad(state_ptr1, "state1"),
   1185             ir_builder_->getInt128Ty()),
   1186         64);
   1187     llvm::Value* state = ir_builder_->CreateOr(state0, state1);
   1188     llvm::Value* updated = ir_builder_->CreateAdd(
   1189         ir_builder_->CreateMul(state, multiplier), increment);
   1190     ir_builder_->CreateStore(
   1191         ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()),
   1192         state_ptr0);
   1193     ir_builder_->CreateStore(
   1194         ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64),
   1195                                  ir_builder_->getInt64Ty()),
   1196         state_ptr1);
   1197 
   1198     return llvm_ir::CreateRor(
   1199         ir_builder_->CreateTrunc(
   1200             ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)),
   1201             ir_builder_->getInt64Ty()),
   1202         ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122),
   1203                                  ir_builder_->getInt64Ty()),
   1204         ir_builder_);
   1205   };
   1206 
   1207   auto get_next_uniform_float = [=]() {
   1208     return ir_builder_->CreateFDiv(
   1209         ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type),
   1210         llvm::ConstantFP::get(param_ir_type, 0x1p64));
   1211   };
   1212 
   1213   return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1214     switch (hlo->random_distribution()) {
   1215       case RNG_UNIFORM: {
   1216         TF_ASSIGN_OR_RETURN(llvm::Value * p,
   1217                             operand_to_generator.at(hlo->operand(0))(index));
   1218         TF_ASSIGN_OR_RETURN(llvm::Value * q,
   1219                             operand_to_generator.at(hlo->operand(1))(index));
   1220         if (primitive_util::IsFloatingPointType(param_prim_type)) {
   1221           return ir_builder_->CreateFAdd(
   1222               ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p),
   1223                                       get_next_uniform_float()),
   1224               p);
   1225         } else {
   1226           auto r = ir_builder_->CreateSub(q, p);
   1227           auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
   1228               llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
   1229               {param_ir_type}, ir_builder_);
   1230           auto in_block = ir_builder_->GetInsertBlock();
   1231 
   1232           // A terminator should be present iff we're emitting code
   1233           // into the middle (as opposed to the end) of a basic block.
   1234           CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(),
   1235                    in_block->getTerminator() == nullptr);
   1236 
   1237           llvm::BasicBlock* body_block;
   1238           llvm::BasicBlock* out_block;
   1239 
   1240           if (ir_builder_->GetInsertPoint() == in_block->end()) {
   1241             body_block = llvm_ir::CreateBasicBlock(
   1242                 nullptr, IrName(hlo, "rng_body"), ir_builder_);
   1243             out_block = llvm_ir::CreateBasicBlock(
   1244                 nullptr, IrName(hlo, "rng_out"), ir_builder_);
   1245             llvm::BranchInst::Create(body_block, in_block);
   1246           } else {
   1247             body_block = in_block->splitBasicBlock(
   1248                 ir_builder_->GetInsertPoint(), "rng_body");
   1249             out_block = body_block->splitBasicBlock(
   1250                 ir_builder_->GetInsertPoint(), "rng_out");
   1251             body_block->getTerminator()->eraseFromParent();
   1252           }
   1253 
   1254           SetToFirstInsertPoint(body_block, ir_builder_);
   1255           auto random = ir_builder_->CreateAnd(
   1256               ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
   1257               ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
   1258                                       leading_zeros));
   1259           llvm::BranchInst::Create(out_block, body_block,
   1260                                    ir_builder_->CreateICmpULT(random, r),
   1261                                    body_block);
   1262           SetToFirstInsertPoint(out_block, ir_builder_);
   1263           return ir_builder_->CreateAdd(
   1264               p, ir_builder_->CreateSelect(
   1265                      ir_builder_->CreateICmpEQ(p, q),
   1266                      llvm::ConstantInt::get(param_ir_type, 0), random));
   1267         }
   1268       }
   1269       case RNG_NORMAL: {
   1270         TF_ASSIGN_OR_RETURN(llvm::Value * m,
   1271                             operand_to_generator.at(hlo->operand(0))(index));
   1272         TF_ASSIGN_OR_RETURN(llvm::Value * s,
   1273                             operand_to_generator.at(hlo->operand(1))(index));
   1274         TF_ASSIGN_OR_RETURN(
   1275             llvm::Value * r,
   1276             EmitErfcInv(param_prim_type,
   1277                         ir_builder_->CreateFMul(
   1278                             llvm::ConstantFP::get(param_ir_type, 2.0),
   1279                             get_next_uniform_float())));
   1280         return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m);
   1281       }
   1282       default:
   1283         return InvalidArgument(
   1284             "unhandled distribution %s",
   1285             RandomDistribution_Name(hlo->random_distribution()).c_str());
   1286     }
   1287   };
   1288 }
   1289 
   1290 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
   1291     const HloInstruction* hlo,
   1292     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
   1293     const {
   1294   switch (hlo->opcode()) {
   1295     case HloOpcode::kAbs:
   1296     case HloOpcode::kRoundNearestAfz:
   1297     case HloOpcode::kCeil:
   1298     case HloOpcode::kConvert:
   1299     case HloOpcode::kBitcastConvert:
   1300     case HloOpcode::kCopy:
   1301     case HloOpcode::kCos:
   1302     case HloOpcode::kExp:
   1303     case HloOpcode::kFloor:
   1304     case HloOpcode::kImag:
   1305     case HloOpcode::kIsFinite:
   1306     case HloOpcode::kLog:
   1307     case HloOpcode::kNegate:
   1308     case HloOpcode::kNot:
   1309     case HloOpcode::kReal:
   1310     case HloOpcode::kSign:
   1311     case HloOpcode::kSin:
   1312     case HloOpcode::kTanh:
   1313       return [this, hlo, &operand_to_generator](
   1314                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1315         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
   1316                             operand_to_generator.at(hlo->operand(0))(
   1317                                 ElementwiseSourceIndex(index, *hlo, 0)));
   1318         return EmitUnaryOp(hlo, operand_value);
   1319       };
   1320     case HloOpcode::kAdd:
   1321     case HloOpcode::kAnd:
   1322     case HloOpcode::kAtan2:
   1323     case HloOpcode::kComplex:
   1324     case HloOpcode::kDivide:
   1325     case HloOpcode::kEq:
   1326     case HloOpcode::kGe:
   1327     case HloOpcode::kGt:
   1328     case HloOpcode::kLe:
   1329     case HloOpcode::kLt:
   1330     case HloOpcode::kMaximum:
   1331     case HloOpcode::kMinimum:
   1332     case HloOpcode::kMultiply:
   1333     case HloOpcode::kNe:
   1334     case HloOpcode::kOr:
   1335     case HloOpcode::kPower:
   1336     case HloOpcode::kRemainder:
   1337     case HloOpcode::kShiftLeft:
   1338     case HloOpcode::kShiftRightArithmetic:
   1339     case HloOpcode::kShiftRightLogical:
   1340     case HloOpcode::kSubtract:
   1341       return [this, hlo, &operand_to_generator](
   1342                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1343         const HloInstruction* lhs = hlo->operand(0);
   1344         const HloInstruction* rhs = hlo->operand(1);
   1345         TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
   1346                             operand_to_generator.at(lhs)(
   1347                                 ElementwiseSourceIndex(index, *hlo, 0)));
   1348         TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
   1349                             operand_to_generator.at(rhs)(
   1350                                 ElementwiseSourceIndex(index, *hlo, 1)));
   1351         return EmitBinaryOp(hlo, lhs_value, rhs_value);
   1352       };
   1353     case HloOpcode::kSelect:
   1354       return [this, hlo, &operand_to_generator](
   1355                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1356         TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
   1357                             operand_to_generator.at(hlo->operand(0))(
   1358                                 ElementwiseSourceIndex(index, *hlo, 0)));
   1359         TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
   1360                             operand_to_generator.at(hlo->operand(1))(
   1361                                 ElementwiseSourceIndex(index, *hlo, 1)));
   1362         TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
   1363                             operand_to_generator.at(hlo->operand(2))(
   1364                                 ElementwiseSourceIndex(index, *hlo, 2)));
   1365         return ir_builder_->CreateSelect(
   1366             ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
   1367             on_true_value, on_false_value);
   1368       };
   1369     case HloOpcode::kClamp:
   1370       return [this, hlo, &operand_to_generator](
   1371                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1372         TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
   1373                             operand_to_generator.at(hlo->operand(0))(
   1374                                 ElementwiseSourceIndex(index, *hlo, 0)));
   1375         TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
   1376                             operand_to_generator.at(hlo->operand(1))(
   1377                                 ElementwiseSourceIndex(index, *hlo, 1)));
   1378         TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
   1379                             operand_to_generator.at(hlo->operand(2))(
   1380                                 ElementwiseSourceIndex(index, *hlo, 2)));
   1381         PrimitiveType prim_type = hlo->shape().element_type();
   1382         if (primitive_util::IsFloatingPointType(prim_type)) {
   1383           return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
   1384         } else if (primitive_util::IsIntegralType(prim_type)) {
   1385           bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
   1386           return EmitIntegralMin(
   1387               max_value, EmitIntegralMax(min_value, arg_value, is_signed),
   1388               is_signed);
   1389         } else {
   1390           return Unimplemented("Clamp unimplemented for %s",
   1391                                PrimitiveType_Name(prim_type).c_str());
   1392         }
   1393       };
   1394     case HloOpcode::kReducePrecision:
   1395       return [this, hlo, &operand_to_generator](
   1396                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1397         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
   1398                             operand_to_generator.at(hlo->operand(0))(
   1399                                 ElementwiseSourceIndex(index, *hlo, 0)));
   1400         return EmitReducePrecision(hlo, operand_value);
   1401       };
   1402     case HloOpcode::kConcatenate:
   1403       return [this, hlo, &operand_to_generator](
   1404                  const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
   1405         const int64 concat_dim = hlo->dimensions(0);
   1406         auto source_index = target_index;
   1407 
   1408         llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
   1409 
   1410         // A terminator should be present iff we're emitting code
   1411         // into the middle (as opposed to the end) of a basic block.
   1412         CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
   1413                  init_block->getTerminator() == nullptr);
   1414 
   1415         llvm::BasicBlock* exit_block;
   1416         if (ir_builder_->GetInsertPoint() == init_block->end()) {
   1417           exit_block = llvm_ir::CreateBasicBlock(
   1418               /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
   1419         } else {
   1420           exit_block = init_block->splitBasicBlock(
   1421               ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge")));
   1422           init_block->getTerminator()->eraseFromParent();
   1423         }
   1424 
   1425         llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
   1426         llvm::PHINode* output =
   1427             ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
   1428                                        hlo->shape().element_type(), module_),
   1429                                    hlo->operands().size());
   1430         auto prior_insert_point = ir_builder_->GetInsertPoint();
   1431 
   1432         ir_builder_->SetInsertPoint(init_block);
   1433 
   1434         for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
   1435              ++operand_idx) {
   1436           const HloInstruction* operand = hlo->operand(operand_idx);
   1437           auto true_block = llvm_ir::CreateBasicBlock(
   1438               exit_block, StrCat("concat_index_from_operand", operand_idx),
   1439               ir_builder_);
   1440           auto false_block = llvm_ir::CreateBasicBlock(
   1441               exit_block, StrCat("concat_index_not_from_operand", operand_idx),
   1442               ir_builder_);
   1443           auto concat_dim_size =
   1444               llvm::ConstantInt::get(source_index[concat_dim]->getType(),
   1445                                      operand->shape().dimensions(concat_dim));
   1446           ir_builder_->CreateCondBr(
   1447               ir_builder_->CreateICmpULT(source_index[concat_dim],
   1448                                          concat_dim_size),
   1449               true_block, false_block);
   1450 
   1451           // Create the terminator of the true block before calling operand
   1452           // generators, because they require non-degenerate basic blocks.
   1453           ir_builder_->SetInsertPoint(
   1454               llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
   1455           TF_ASSIGN_OR_RETURN(llvm::Value * value,
   1456                               operand_to_generator.at(operand)(source_index));
   1457           output->addIncoming(value, ir_builder_->GetInsertBlock());
   1458 
   1459           // Subtract the size of the concat dimension of the current operand
   1460           // from the source index.
   1461           ir_builder_->SetInsertPoint(false_block);
   1462           source_index[concat_dim] =
   1463               ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
   1464         }
   1465 
   1466         ir_builder_->CreateUnreachable();
   1467         ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
   1468         return output;
   1469       };
   1470     case HloOpcode::kReverse:
   1471       return [this, hlo, &operand_to_generator](
   1472                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
   1473         const HloInstruction* operand = hlo->operand(0);
   1474         auto source_index = target_index;
   1475         for (int64 dim : hlo->dimensions()) {
   1476           source_index[dim] = ir_builder_->CreateSub(
   1477               llvm::ConstantInt::get(target_index[dim]->getType(),
   1478                                      hlo->shape().dimensions(dim) - 1),
   1479               target_index[dim]);
   1480         }
   1481         return operand_to_generator.at(operand)(source_index);
   1482       };
   1483     case HloOpcode::kBroadcast:
   1484       return [this, hlo, &operand_to_generator](
   1485                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
   1486         // The `dimensions` member of the broadcast instruction maps from
   1487         // input dimensions to output dimensions.
   1488         const HloInstruction* operand = hlo->operand(0);
   1489         int64 rank = ShapeUtil::Rank(operand->shape());
   1490         IrArray::Index source_index(rank);
   1491         for (int64 i = 0; i < rank; ++i) {
   1492           source_index[i] = target_index[hlo->dimensions(i)];
   1493         }
   1494         return operand_to_generator.at(operand)(source_index);
   1495       };
   1496     case HloOpcode::kSlice:
   1497       return [this, hlo, &operand_to_generator](
   1498                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1499         IrArray::Index sliced_index = index.SourceIndexOfSlice(
   1500             /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
   1501             /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
   1502         return operand_to_generator.at(hlo->operand(0))(sliced_index);
   1503       };
   1504     case HloOpcode::kDynamicSlice:
   1505       return [this, hlo, &operand_to_generator](
   1506                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1507         // Emit IR to read dynamic start indices from hlo->operand(1).
   1508         const HloInstruction* input_hlo = hlo->operand(0);
   1509         const int64 rank = ShapeUtil::Rank(input_hlo->shape());
   1510         llvm_ir::IrArray::Index slice_start_index(rank);
   1511         for (int64 i = 0; i < rank; ++i) {
   1512           llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
   1513           TF_ASSIGN_OR_RETURN(
   1514               llvm::Value * start_index_value,
   1515               operand_to_generator.at(hlo->operand(1))(dim_index));
   1516           start_index_value->setName(
   1517               AsStringRef(IrName(hlo, StrCat("start_idx", i))));
   1518           slice_start_index[i] = start_index_value;
   1519         }
   1520 
   1521         llvm_ir::IrArray::Index input_index(rank);
   1522         for (int64 i = 0; i < rank; ++i) {
   1523           // Emit IR which computes:
   1524           //   input_index = (start_index + offset_index) % dim_size
   1525           // Security note: this is the code that keeps the indices in-bounds.
   1526           llvm::Value* dim_size = llvm::ConstantInt::get(
   1527               index[i]->getType(), input_hlo->shape().dimensions(i));
   1528           llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
   1529               slice_start_index[i], index[i]->getType());
   1530           input_index[i] = ir_builder_->CreateURem(
   1531               ir_builder_->CreateAdd(start_index, index[i]), dim_size);
   1532         }
   1533         return operand_to_generator.at(input_hlo)(input_index);
   1534       };
   1535     case HloOpcode::kDynamicUpdateSlice:
   1536       return [this, hlo, &operand_to_generator](
   1537                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
   1538         const HloInstruction* input_hlo = hlo->operand(0);
   1539         const HloInstruction* update_hlo = hlo->operand(1);
   1540         const HloInstruction* start_hlo = hlo->operand(2);
   1541         // Calculate slice start/end indices.
   1542         const int64 rank = ShapeUtil::Rank(input_hlo->shape());
   1543         llvm_ir::IrArray::Index slice_start_index(rank);
   1544         llvm_ir::IrArray::Index slice_limit_index(rank);
   1545         // Slice starts at update[index - slice_start_index_adjusted],
   1546         // where adjusted value = slice_start_index when in bounds, and
   1547         // adjusted value = slice_start_index - input_dim, when wrapping.
   1548         llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
   1549 
   1550         // Slice intersection gathers (ANDs) conditions on all ranks for which
   1551         // 'input' is set to 'update'
   1552         llvm::Value* slice_intersection = ir_builder_->getTrue();
   1553 
   1554         for (int64 i = 0; i < rank; ++i) {
   1555           // Emit IR to read dynamic start indices from 'start_hlo'.
   1556           llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
   1557           TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
   1558                               operand_to_generator.at(start_hlo)(dim_index));
   1559           start_index_value->setName(
   1560               AsStringRef(IrName(hlo, StrCat("start_idx", i))));
   1561           slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
   1562               start_index_value, index[i]->getType());
   1563 
   1564           llvm::Value* input_dim_size = llvm::ConstantInt::get(
   1565               index[i]->getType(), input_hlo->shape().dimensions(i));
   1566           llvm::Value* update_dim_size = llvm::ConstantInt::get(
   1567               index[i]->getType(), update_hlo->shape().dimensions(i));
   1568 
   1569           // Generate code to handle wrapping semantics:
   1570           // slice_start_index[i] = slice_start_index[i] % input_dim_size;
   1571           // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
   1572           // slice_start_index[i] is updated in place and it will now be in
   1573           // range. slice_limit_index[i] may be out of range, and it's being
   1574           // URem-ed below if so.
   1575           slice_start_index[i] =
   1576               ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
   1577           slice_limit_index[i] =
   1578               ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
   1579 
   1580           // Test if slice_limit_index[i] is in bounds
   1581           llvm::Value* in_bounds =
   1582               ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
   1583           llvm_ir::LlvmIfData if_in_bounds =
   1584               llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
   1585 
   1586           // Handle true BB (slice_limit_index[i] <= input_dim_size).
   1587           SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
   1588           // Check that index[i] >= slice_start_index[i] &&
   1589           //            index[i] < slice_limit_index[i]
   1590           llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
   1591               slice_intersection,
   1592               ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
   1593               "slice_intersection_in");
   1594           slice_intersection_in_bounds = ir_builder_->CreateAnd(
   1595               slice_intersection_in_bounds,
   1596               ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
   1597               "slice_intersection_in");
   1598 
   1599           // Handle false BB (slice_limit_index[i] > input_dim_size).
   1600           SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
   1601           // Check that index[i] >= slice_start_index[i] ||
   1602           //            index[i] < slice_limit_index[i]%input_dim_size.
   1603           llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
   1604               index[i],
   1605               ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
   1606           llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
   1607               ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
   1608               index_wraps, "slice_intersection_out");
   1609           llvm::Value* slice_intersection_out_of_bounds =
   1610               ir_builder_->CreateAnd(slice_intersection, slice_intersection_or,
   1611                                      "slice_intersection_out");
   1612           // Create value for slice_start_index_adjusted[i] when out of bounds.
   1613           // This is an if nested within the out-of-bounds if.
   1614           llvm_ir::LlvmIfData if_start_needs_adjustment =
   1615               llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
   1616           SetToFirstInsertPoint(if_start_needs_adjustment.true_block,
   1617                                 ir_builder_);
   1618           llvm::Value* slice_start_index_adjusted_oob =
   1619               ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
   1620           SetToFirstInsertPoint(if_start_needs_adjustment.after_block,
   1621                                 ir_builder_);
   1622           llvm::PHINode* slice_start_index_adjusted_phi =
   1623               ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(),
   1624                                      2);
   1625           slice_start_index_adjusted_phi->addIncoming(
   1626               slice_start_index_adjusted_oob,
   1627               if_start_needs_adjustment.true_block);
   1628           slice_start_index_adjusted_phi->addIncoming(
   1629               slice_start_index[i], if_start_needs_adjustment.false_block);
   1630           // End of if within if.
   1631 
   1632           // After checking in/out of bounds.
   1633           SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
   1634           llvm::PHINode* phi_slice_intersection =
   1635               ir_builder_->CreatePHI(slice_intersection->getType(), 2);
   1636           phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
   1637                                               if_in_bounds.true_block);
   1638           phi_slice_intersection->addIncoming(
   1639               slice_intersection_out_of_bounds,
   1640               if_start_needs_adjustment.after_block);
   1641           slice_intersection = phi_slice_intersection;
   1642 
   1643           llvm::PHINode* phi_index =
   1644               ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
   1645           phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
   1646           phi_index->addIncoming(slice_start_index_adjusted_phi,
   1647                                  if_start_needs_adjustment.after_block);
   1648           slice_start_index_adjusted[i] = phi_index;
   1649         }
   1650 
   1651         // Emit:
   1652         // if (slice_intersection) -> return data from 'update'.
   1653         // else                    -> return data from 'input'.
   1654         llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
   1655             llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
   1656                                            module_),
   1657             "ret_value_addr", ir_builder_);
   1658         llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
   1659             slice_intersection, "slice_intersection", ir_builder_);
   1660 
   1661         // Handle true BB (return data from 'update')
   1662         SetToFirstInsertPoint(if_data.true_block, ir_builder_);
   1663         // Compute update index for intersection case.
   1664         llvm_ir::IrArray::Index update_index(rank);
   1665         for (int64 i = 0; i < rank; ++i) {
   1666           llvm::Value* update_dim_size = llvm::ConstantInt::get(
   1667               index[i]->getType(), update_hlo->shape().dimensions(i));
   1668           // NOTE: Subtraction will be positive due to bounds checking above.
   1669           update_index[i] = ir_builder_->CreateURem(
   1670               ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
   1671               update_dim_size);
   1672         }
   1673         TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
   1674                             operand_to_generator.at(update_hlo)(update_index));
   1675         ir_builder_->CreateStore(true_value, ret_value_addr);
   1676 
   1677         // Handle false BB (return data from 'input')
   1678         SetToFirstInsertPoint(if_data.false_block, ir_builder_);
   1679         TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
   1680                             operand_to_generator.at(input_hlo)(index));
   1681         ir_builder_->CreateStore(false_value, ret_value_addr);
   1682 
   1683         SetToFirstInsertPoint(if_data.after_block, ir_builder_);
   1684         return ir_builder_->CreateLoad(ret_value_addr);
   1685       };
   1686     case HloOpcode::kReshape:
   1687       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
   1688                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
   1689       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
   1690         const HloInstruction* operand = hlo->operand(0);
   1691         return operand_to_generator.at(operand)(index.SourceIndexOfReshape(
   1692             hlo->shape(), operand->shape(), ir_builder_));
   1693       };
   1694     case HloOpcode::kTranspose:
   1695       return [this, hlo,
   1696               &operand_to_generator](const IrArray::Index& target_index) {
   1697         return operand_to_generator.at(hlo->operand(0))(
   1698             target_index.SourceIndexOfTranspose(
   1699                 hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(),
   1700                 ir_builder_));
   1701       };
   1702     case HloOpcode::kRng:
   1703       return MakeRngElementGenerator(hlo, operand_to_generator);
   1704     case HloOpcode::kPad:
   1705       return [=, &operand_to_generator](
   1706                  const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
   1707         auto index = padded_index;
   1708         llvm::Value* in_bounds = ir_builder_->getTrue();
   1709         for (size_t i = 0; i < index.size(); ++i) {
   1710           auto index_typed_const = [=](int64 n) {
   1711             return llvm::ConstantInt::get(index[i]->getType(), n);
   1712           };
   1713           const auto& pad_dim = hlo->padding_config().dimensions(i);
   1714           index[i] = ir_builder_->CreateSub(
   1715               index[i], index_typed_const(pad_dim.edge_padding_low()));
   1716           in_bounds = ir_builder_->CreateAnd(
   1717               in_bounds,
   1718               ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
   1719               "in_bounds");
   1720           in_bounds = ir_builder_->CreateAnd(
   1721               in_bounds,
   1722               ir_builder_->CreateICmpEQ(
   1723                   index_typed_const(0),
   1724                   ir_builder_->CreateURem(
   1725                       index[i],
   1726                       index_typed_const(pad_dim.interior_padding() + 1))),
   1727               "in_bounds");
   1728           index[i] = ir_builder_->CreateSDiv(
   1729               index[i], index_typed_const(pad_dim.interior_padding() + 1));
   1730           in_bounds = ir_builder_->CreateAnd(
   1731               in_bounds,
   1732               ir_builder_->CreateICmpSLT(
   1733                   index[i],
   1734                   index_typed_const(hlo->operand(0)->shape().dimensions(i))),
   1735               "in_bounds");
   1736         }
   1737 
   1738         // if (in_bounds) {
   1739         //   ret_value = operand0[index];  // source
   1740         // } else {
   1741         //   ret_value = *operand1;        // padding
   1742         // }
   1743         llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
   1744             llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
   1745                                            module_),
   1746             "pad_result_addr", ir_builder_);
   1747         llvm_ir::LlvmIfData if_data =
   1748             llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
   1749         SetToFirstInsertPoint(if_data.true_block, ir_builder_);
   1750         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
   1751                             operand_to_generator.at(hlo->operand(0))(index));
   1752         ir_builder_->CreateStore(operand_value, ret_value_addr);
   1753 
   1754         SetToFirstInsertPoint(if_data.false_block, ir_builder_);
   1755         TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
   1756                             operand_to_generator.at(hlo->operand(1))({}));
   1757         ir_builder_->CreateStore(padding_value, ret_value_addr);
   1758 
   1759         SetToFirstInsertPoint(if_data.after_block, ir_builder_);
   1760         // Don't create phi(operand_value, padding_value) here, because invoking
   1761         // operand_to_generator may create new basic blocks, making the parent
   1762         // of operand_value or padding_value no longer a predecessor of
   1763         // if_data.after_block.
   1764         return ir_builder_->CreateLoad(ret_value_addr);
   1765       };
   1766 
   1767     case HloOpcode::kDot:
   1768       return [=, &operand_to_generator](const IrArray::Index& dot_result_index)
   1769                  -> StatusOr<llvm::Value*> {
   1770         auto lhs_generator = operand_to_generator.at(hlo->operand(0));
   1771         auto rhs_generator = operand_to_generator.at(hlo->operand(1));
   1772         int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
   1773             hlo->operand(0)->shape().dimensions_size() - 1);
   1774         int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
   1775         int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
   1776 
   1777         std::unique_ptr<llvm_ir::ForLoop> inner_loop =
   1778             llvm_ir::ForLoop::EmitForLoop(
   1779                 IrName(hlo, "inner"), ir_builder_->getInt64(0),
   1780                 ir_builder_->getInt64(contracted_dim_size),
   1781                 ir_builder_->getInt64(1), ir_builder_);
   1782 
   1783         SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(),
   1784                               ir_builder_);
   1785         PrimitiveType primitive_type = hlo->shape().element_type();
   1786         llvm::Type* primitive_type_llvm =
   1787             llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
   1788         llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
   1789             primitive_type_llvm, "dot_acc", ir_builder_);
   1790         ir_builder_->CreateStore(
   1791             llvm::Constant::getNullValue(primitive_type_llvm),
   1792             accumulator_alloca);
   1793 
   1794         SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);
   1795 
   1796         // This is the inner reduction loop for a dot operation that produces
   1797         // one element in the output.  If the operands to the dot operation have
   1798         // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
   1799         // Given an output index [a,b,c,d,e] in the result, we compute:
   1800         //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
   1801 
   1802         IrArray::Index lhs_index, rhs_index;
   1803 
   1804         for (int64 i = 0; i < lhs_dims - 1; i++) {
   1805           lhs_index.push_back(dot_result_index[i]);
   1806         }
   1807         lhs_index.push_back(inner_loop->GetIndVarValue());
   1808 
   1809         for (int64 i = 0; i < rhs_dims - 2; i++) {
   1810           rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
   1811         }
   1812         rhs_index.push_back(inner_loop->GetIndVarValue());
   1813         rhs_index.push_back(dot_result_index.back());
   1814 
   1815         llvm::Value* current_accumulator =
   1816             ir_builder_->CreateLoad(accumulator_alloca);
   1817         TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
   1818         TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
   1819         llvm::Value* next_accumulator;
   1820         if (primitive_util::IsComplexType(primitive_type)) {
   1821           llvm::Value* product_real = ir_builder_->CreateFSub(
   1822               ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
   1823                                       EmitExtractReal(rhs_value)),
   1824               ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
   1825                                       EmitExtractImag(rhs_value)));
   1826           llvm::Value* product_imag = ir_builder_->CreateFAdd(
   1827               ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
   1828                                       EmitExtractImag(rhs_value)),
   1829               ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
   1830                                       EmitExtractReal(rhs_value)));
   1831           next_accumulator = ir_builder_->CreateInsertValue(
   1832               current_accumulator,
   1833               ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
   1834                                       product_real),
   1835               {0});
   1836           next_accumulator = ir_builder_->CreateInsertValue(
   1837               next_accumulator,
   1838               ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
   1839                                       product_imag),
   1840               {1});
   1841         } else if (primitive_util::IsFloatingPointType(primitive_type)) {
   1842           next_accumulator = ir_builder_->CreateFAdd(
   1843               current_accumulator,
   1844               ir_builder_->CreateFMul(lhs_value, rhs_value));
   1845         } else {
   1846           next_accumulator = ir_builder_->CreateAdd(
   1847               current_accumulator,
   1848               ir_builder_->CreateMul(lhs_value, rhs_value));
   1849         }
   1850         ir_builder_->CreateStore(next_accumulator, accumulator_alloca);
   1851 
   1852         SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
   1853         return ir_builder_->CreateLoad(accumulator_alloca);
   1854       };
   1855     default:
   1856       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
   1857         return Unimplemented("Unhandled opcode for elemental IR emission: %s",
   1858                              HloOpcodeString(hlo->opcode()).c_str());
   1859       };
   1860   }
   1861 }
   1862 
   1863 llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
   1864   return ir_builder_->CreateExtractValue(value, {0});
   1865 }
   1866 
   1867 llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
   1868   return ir_builder_->CreateExtractValue(value, {1});
   1869 }
   1870 
   1871 llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
   1872                                                     llvm::Value* real,
   1873                                                     llvm::Value* imag) const {
   1874   auto cplx_type =
   1875       llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
   1876   auto complex = ir_builder_->CreateInsertValue(
   1877       llvm::ConstantAggregateZero::get(cplx_type), real, {0});
   1878   if (imag != nullptr) {
   1879     complex = ir_builder_->CreateInsertValue(complex, imag, {1});
   1880   }
   1881   return complex;
   1882 }
   1883 
   1884 }  // namespace xla
   1885