1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" 17 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" 18 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" 19 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" 20 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 21 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 22 23 namespace xla { 24 namespace llvm_ir { 25 26 bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, 27 const BufferAssignment& assignment) { 28 CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); 29 const HloInstruction* operand = dynamic_update_slice->operand(0); 30 return assignment.HasTopLevelAllocation(dynamic_update_slice) && 31 assignment.HasTopLevelAllocation(operand) && 32 assignment.SharesTopLevelSlice(dynamic_update_slice, operand); 33 } 34 35 // Shared implementation of EmitDynamicUpdateSliceInPlace and 36 // EmitFusedDynamicUpdateSliceInPlace. 37 // 38 // Emits a sequential loop if launch_dimensions is null. 
// Given the ordinal of one start-index operand (0 <= ordinal < rank), emits
// the IR that loads that start index and returns it as an llvm::Value.
using IndexGenerator = std::function<StatusOr<llvm::Value*>(int64)>;

// Emits a loop over update_shape that writes the generated update elements
// into output_array at an offset given by the (clamped) start indices.
// Emits a GPU parallel loop when launch_dimensions is non-null, otherwise a
// sequential loop.
static Status EmitDynamicUpdateSliceInPlaceImpl(
    const Shape& update_shape, const IndexGenerator& start_indices_generator,
    bool is_signed, ElementGenerator update_array_generator,
    const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
    absl::string_view name, llvm::IRBuilder<>* b) {
  const Shape& output_shape = output_array.GetShape();

  // Read start indices from start_indices_generator.
  const int64 rank = output_shape.rank();
  std::vector<llvm::Value*> start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i));
    // Materialize the dimension sizes as constants of the same integer type
    // as the loaded start index, so the compares/arithmetic below type-check.
    llvm::Value* output_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), output_shape.dimensions(i));
    llvm::Value* update_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), update_shape.dimensions(i));

    // Clamp the start index so that the update region fits in the operand.
    // start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
    llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size);
    llvm::Value* zero =
        llvm::ConstantInt::get(start_multi_index[i]->getType(), 0);
    // start = max(start, 0): select zero when zero >= start.  is_signed picks
    // signed vs. unsigned comparison predicates.
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
                                                : llvm::ICmpInst::ICMP_UGE,
                                      zero, start_multi_index[i]),
                        zero, start_multi_index[i]);

    // start = min(start, max_bound): select max_bound when max_bound <= start.
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
                                                : llvm::ICmpInst::ICMP_ULE,
                                      max_bound, start_multi_index[i]),
                        max_bound, start_multi_index[i]);
  }

  auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
    // Calculate output_index, where we'll write the value from update.  For
    // each dimension,
    //
    //   output_index[dim] = start_index[dim] + update_index[dim]
    //
    std::vector<llvm::Value*> output_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      // Widen (sign-extend) or bitcast the clamped start index to the loop
      // induction variable's type before adding.
      llvm::Value* start_index0 = b->CreateSExtOrBitCast(
          start_multi_index[i], update_index[i]->getType());
      output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]);
    }

    // Do output[output_index] = update[update_index].
    IrArray::Index output_index(output_multi_index, output_shape,
                                b->getInt64Ty());
    TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
                        update_array_generator(update_index));
    output_array.EmitWriteArrayElement(output_index, update_data, b);
    return Status::OK();
  };

  if (launch_dimensions != nullptr) {
    return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
                                    *launch_dimensions, b)
        .EmitLoop(name);
  }
  return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}

// Emits a sequential in-place dynamic-update-slice for an unfused op.
// operand_arrays holds the op's operands: [0] the input (aliasing the
// output), [1] the update, and [2..2+rank) the per-dimension start indices
// (read below with a scalar Index, so each is presumably a scalar operand —
// NOTE(review): confirm against the caller's operand layout).
Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
                                     const IrArray& output_array,
                                     absl::string_view name,
                                     llvm::IRBuilder<>* b) {
  VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;

  // No need to use operand_arrays[0], the input array of the
  // dynamic-update-slice, because we know it aliases the op's output.
  IrArray update_array = operand_arrays[1];
  IrArray start_indices_array = operand_arrays[2];
  Shape output_shape = output_array.GetShape();
  Shape update_shape = update_array.GetShape();

  // Start index i lives in operand i+2; read it with a rank-0 index.
  IndexGenerator start_indices_generator = [&](int64 index) {
    return operand_arrays[2 + index].EmitReadArrayElement(
        IrArray::Index(b->getInt64Ty()), b);
  };
  ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
    return update_array.EmitReadArrayElement(index, b);
  };

  // Signedness of the first start-index operand chooses signed vs. unsigned
  // clamping; presumably all start-index operands share an element type.
  bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      output_array, /*launch_dimensions=*/nullptr, name, b);
}

// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
// EmitParallelFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
          << fusion->ToShortString();

  // The fused computation's root must be the dynamic-update-slice itself.
  auto* dynamic_update_slice = fusion->fused_expression_root();

  const auto* update = dynamic_update_slice->operand(1);
  const auto* start_indices = dynamic_update_slice->operand(2);
  Shape update_shape = update->shape();

  // Our in-place dynamic-update-slice implementation emits a loop over
  // update_shape.  To emit a cache-friendly loop, we need to know that shape's
  // layout.
  //
  // update_shape is inside a fusion node -- it's never materialized in memory
  // and thus doesn't have a layout.  In this case we use the layout of the
  // fusion node for iteration, since that corresponds to the order in memory of
  // the buffer we'll be writing to.
  //
  // (This isn't necessarily optimal; in some cases it might be faster to peek
  // through the chain of ops that gives us the update operand and use the
  // layout of its source buffer(s).  But this is no worse than we do with
  // fusion elsewhere.)
  TF_RETURN_IF_ERROR(
      LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));

  // Create element generators for update and start_indices.
  FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
                               elemental_emitter);
  TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
  ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);

  // Start index i is fused operand 2 + i; evaluate it at a rank-0 index.
  IndexGenerator start_indices_generator = [&](int64 index) {
    ElementGenerator element_generator =
        fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
    return element_generator(IrArray::Index(b->getInt64Ty()));
  };
  // As above, operand 2's signedness stands in for all start-index operands.
  bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      fusion_output_array, launch_dimensions, IrName(fusion), b);
}

// Sequential variant: no launch dimensions, so the impl emits a plain loop.
Status EmitFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, std::move(operand_arrays_generator), fusion_output_array,
      elemental_emitter,
      /*launch_dimensions=*/nullptr, b);
}

// GPU-parallel variant: forwards the caller's launch dimensions to the impl.
Status EmitParallelFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, std::move(operand_arrays_generator), fusion_output_array,
      elemental_emitter, &launch_dimensions, b);
}

}  // namespace llvm_ir
}  // namespace xla