Home | History | Annotate | Download | only in llvm_ir
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
     17 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
     18 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
     19 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
     20 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     21 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
     22 
     23 namespace xla {
     24 namespace llvm_ir {
     25 
     26 bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
     27                                   const BufferAssignment& assignment) {
     28   CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
     29   const HloInstruction* operand = dynamic_update_slice->operand(0);
     30   return assignment.HasTopLevelAllocation(dynamic_update_slice) &&
     31          assignment.HasTopLevelAllocation(operand) &&
     32          assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
     33 }
     34 
     35 // Shared implementation of EmitDynamicUpdateSliceInPlace and
     36 // EmitFusedDynamicUpdateSliceInPlace.
     37 //
     38 // Emits a sequential loop if launch_dimensions is null.
     39 using IndexGenerator = std::function<StatusOr<llvm::Value*>(int64)>;
     40 
     41 static Status EmitDynamicUpdateSliceInPlaceImpl(
     42     const Shape& update_shape, const IndexGenerator& start_indices_generator,
     43     bool is_signed, ElementGenerator update_array_generator,
     44     const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
     45     absl::string_view name, llvm::IRBuilder<>* b) {
     46   const Shape& output_shape = output_array.GetShape();
     47 
     48   // Read start indices from start_indices_generator.
     49   const int64 rank = output_shape.rank();
     50   std::vector<llvm::Value*> start_multi_index(rank);
     51   for (int64 i = 0; i < rank; ++i) {
     52     TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i));
     53     llvm::Value* output_dim_size = llvm::ConstantInt::get(
     54         start_multi_index[i]->getType(), output_shape.dimensions(i));
     55     llvm::Value* update_dim_size = llvm::ConstantInt::get(
     56         start_multi_index[i]->getType(), update_shape.dimensions(i));
     57 
     58     // Clamp the start index so that the update region fits in the operand.
     59     // start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
     60     llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size);
     61     llvm::Value* zero =
     62         llvm::ConstantInt::get(start_multi_index[i]->getType(), 0);
     63     start_multi_index[i] =
     64         b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
     65                                                 : llvm::ICmpInst::ICMP_UGE,
     66                                       zero, start_multi_index[i]),
     67                         zero, start_multi_index[i]);
     68 
     69     start_multi_index[i] =
     70         b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
     71                                                 : llvm::ICmpInst::ICMP_ULE,
     72                                       max_bound, start_multi_index[i]),
     73                         max_bound, start_multi_index[i]);
     74   }
     75 
     76   auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
     77     // Calculate output_index, where we'll write the value from update.  For
     78     // each dimension,
     79     //
     80     //   output_index[dim] = start_index[dim] + update_index[dim]
     81     //
     82     std::vector<llvm::Value*> output_multi_index(rank);
     83     for (int64 i = 0; i < rank; ++i) {
     84       llvm::Value* start_index0 = b->CreateSExtOrBitCast(
     85           start_multi_index[i], update_index[i]->getType());
     86       output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]);
     87     }
     88 
     89     // Do output[output_index] = update[update_index].
     90     IrArray::Index output_index(output_multi_index, output_shape,
     91                                 b->getInt64Ty());
     92     TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
     93                         update_array_generator(update_index));
     94     output_array.EmitWriteArrayElement(output_index, update_data, b);
     95     return Status::OK();
     96   };
     97 
     98   if (launch_dimensions != nullptr) {
     99     return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
    100                                     *launch_dimensions, b)
    101         .EmitLoop(name);
    102   }
    103   return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
    104 }
    105 
    106 Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
    107                                      const IrArray& output_array,
    108                                      absl::string_view name,
    109                                      llvm::IRBuilder<>* b) {
    110   VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
    111 
    112   // No need to use operand_arrays[0], the input array of the
    113   // dynamic-update-slice, because we know it aliases the op's output.
    114   IrArray update_array = operand_arrays[1];
    115   IrArray start_indices_array = operand_arrays[2];
    116   Shape output_shape = output_array.GetShape();
    117   Shape update_shape = update_array.GetShape();
    118 
    119   IndexGenerator start_indices_generator = [&](int64 index) {
    120     return operand_arrays[2 + index].EmitReadArrayElement(
    121         IrArray::Index(b->getInt64Ty()), b);
    122   };
    123   ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
    124     return update_array.EmitReadArrayElement(index, b);
    125   };
    126 
    127   bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
    128   return EmitDynamicUpdateSliceInPlaceImpl(
    129       update_shape, start_indices_generator, is_signed, update_array_generator,
    130       output_array, /*launch_dimensions=*/nullptr, name, b);
    131 }
    132 
    133 // Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
    134 // EmitParallelFusedDynamicUpdateSliceInPlace.
    135 //
    136 // Emits a sequential loop if launch_dimensions is null.
    137 static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
    138     HloInstruction* fusion,
    139     GeneratorForOperandIrArrays operand_arrays_generator,
    140     const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    141     const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
    142   CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
    143   VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
    144           << fusion->ToShortString();
    145 
    146   auto* dynamic_update_slice = fusion->fused_expression_root();
    147 
    148   const auto* update = dynamic_update_slice->operand(1);
    149   const auto* start_indices = dynamic_update_slice->operand(2);
    150   Shape update_shape = update->shape();
    151 
    152   // Our in-place dynamic-update-slice implementation emits a loop over
    153   // update_shape.  To emit a cache-friendly loop, we need to know that shape's
    154   // layout.
    155   //
    156   // update_shape is inside a fusion node -- it's never materialized in memory
    157   // and thus doesn't have a layout.  In this case we use the layout of the
    158   // fusion node for iteration, since that corresponds to the order in memory of
    159   // the buffer we'll be writing to.
    160   //
    161   // (This isn't necessarily optimal; in some cases it might be faster to peek
    162   // through the chain of ops that gives us the update operand and use the
    163   // layout of its source buffer(s).  But this is no worse than we do with
    164   // fusion elsewhere.)
    165   TF_RETURN_IF_ERROR(
    166       LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));
    167 
    168   // Create element generators for update and start_indices.
    169   FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
    170                                elemental_emitter);
    171   TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
    172   ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);
    173 
    174   IndexGenerator start_indices_generator = [&](int64 index) {
    175     ElementGenerator element_generator =
    176         fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
    177     return element_generator(IrArray::Index(b->getInt64Ty()));
    178   };
    179   bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
    180   return EmitDynamicUpdateSliceInPlaceImpl(
    181       update_shape, start_indices_generator, is_signed, update_array_generator,
    182       fusion_output_array, launch_dimensions, IrName(fusion), b);
    183 }
    184 
    185 Status EmitFusedDynamicUpdateSliceInPlace(
    186     HloInstruction* fusion,
    187     GeneratorForOperandIrArrays operand_arrays_generator,
    188     const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    189     llvm::IRBuilder<>* b) {
    190   return EmitFusedDynamicUpdateSliceInPlaceImpl(
    191       fusion, std::move(operand_arrays_generator), fusion_output_array,
    192       elemental_emitter,
    193       /*launch_dimensions=*/nullptr, b);
    194 }
    195 
    196 Status EmitParallelFusedDynamicUpdateSliceInPlace(
    197     HloInstruction* fusion,
    198     GeneratorForOperandIrArrays operand_arrays_generator,
    199     const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    200     const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
    201   return EmitFusedDynamicUpdateSliceInPlaceImpl(
    202       fusion, std::move(operand_arrays_generator), fusion_output_array,
    203       elemental_emitter, &launch_dimensions, b);
    204 }
    205 
    206 }  // namespace llvm_ir
    207 }  // namespace xla
    208