1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <iterator> 17 18 #include "tensorflow/compiler/xla/service/cpu/ir_function.h" 19 20 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" 21 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" 22 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 23 #include "tensorflow/compiler/xla/status_macros.h" 24 25 namespace xla { 26 27 namespace { 28 using llvm_ir::AsStringRef; 29 } // namespace 30 31 namespace cpu { 32 33 static std::vector<llvm::Type*> GetComputeFunctionParams( 34 llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) { 35 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext()); 36 llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); 37 llvm::Type* i64_ptr_type = 38 llvm::Type::getInt64PtrTy(llvm_module->getContext()); 39 std::vector<llvm::Type*> compute_function_params( 40 {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); 41 if (num_dynamic_loop_bounds > 0) { 42 compute_function_params.push_back(i64_ptr_type); 43 } 44 compute_function_params.push_back(i64_ptr_type); 45 return compute_function_params; 46 } 47 48 IrFunction::IrFunction(const string& function_name, 49 llvm::Function::LinkageTypes linkage, 50 const bool optimize_for_size_requested, 51 const bool enable_fast_math, llvm::Module* llvm_module, 52 llvm::IRBuilder<>* ir_builder, 53 int64 num_dynamic_loop_bounds) 54 : ir_builder_(ir_builder), 55 llvm_module_(llvm_module), 56 caller_insert_point_guard_(*ir_builder), 57 num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { 58 Initialize(function_name, linkage, optimize_for_size_requested, 59 enable_fast_math); 60 } 61 62 IrFunction::~IrFunction() { 63 // Emit function return value. 64 ir_builder_->CreateRetVoid(); 65 } 66 67 DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { 68 DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_); 69 for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { 70 dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0); 71 dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1); 72 } 73 return dynamic_loop_bounds; 74 } 75 76 void IrFunction::Initialize(const string& function_name, 77 llvm::Function::LinkageTypes linkage, 78 const bool optimize_for_size_requested, 79 const bool enable_fast_math) { 80 // The function signature is: 81 // void function(i8* retval, i8* run_options, i8** params, i8** temps, 82 // i64* dynamic_loop_bounds, i64* prof_counters) 83 // 84 // retval: points to the returned value. 85 // params: address of an array with pointers to parameters. 86 // temps: address of an array with pointers to temporary buffers. 87 // 88 // Therefore, the generated function's signature (FunctionType) is statically 89 // determined - parameter unpacking is done in code generated into the 90 // function, rather than by a prologue dictated by the platform ABI. 91 // 92 // /--------------\ 93 // retval ----------> | return value | 94 // \--------------/ 95 // 96 // /-------------------------------\ 97 // run_options -----> | xla::ExecutableRunOptions | 98 // \-------------------------------/ 99 // 100 // /---------------------------------------------\ 101 // params --------> | param 0 | param 1 | ..... | param N-1 | 102 // | addr | addr | | addr | 103 // \---------------------------------------------/ 104 // | | | 105 // | | | 106 // V V V 107 // /---------\ /---------\ /-----------\ 108 // | param 0 | | param 1 | | param N-1 | 109 // \---------/ \---------/ \-----------/ 110 // 111 // /---------------------------------------------\ 112 // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | 113 // | addr | addr | | addr | 114 // \---------------------------------------------/ 115 // | | | 116 // | | | 117 // V V V 118 // /---------\ /---------\ /-----------\ 119 // | temp 0 | | temp 1 | | temp N-1 | 120 // \---------/ \---------/ \-----------/ 121 // 122 // /--------------------------------------------\ 123 // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| 124 // (elided for aot) \--------------------------------------------/ 125 // 126 // /---------------------------------------------\ 127 // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | 128 // \---------------------------------------------/ 129 130 // Even though the type of params and temps is void** in the host's view, in 131 // LLVM IR this is represented by i8*, similarly to void*. It's up to the code 132 // to use GEPs to unravel the indirection layers. 133 llvm::FunctionType* function_type = llvm::FunctionType::get( 134 /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), 135 /*Params=*/ 136 GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_), 137 /*isVarArg=*/false); 138 139 // Functions with local linkage get an inlining bonus. Because we know 140 // a-priori that embedded functions (non-entry functions) will not have its 141 // name resolved, give it local linkage. 142 function_ = 143 llvm_ir::CreateFunction(function_type, linkage, 144 /*enable_fast_math=*/enable_fast_math, 145 /*optimize_for_size=*/optimize_for_size_requested, 146 function_name, llvm_module_); 147 148 // Set meaningful names for the function's arguments: useful for debugging. 149 llvm::Function::arg_iterator arg_iter = function_->arg_begin(); 150 arg_iter->setName("retval"); 151 result_arg_ = &*arg_iter; 152 (++arg_iter)->setName("run_options"); 153 exec_run_options_arg_ = &*arg_iter; 154 (++arg_iter)->setName("params"); 155 parameters_arg_ = &*arg_iter; 156 (++arg_iter)->setName("temps"); 157 temp_buffers_arg_ = &*arg_iter; 158 if (num_dynamic_loop_bounds_ > 0) { 159 (++arg_iter)->setName("dynamic_loop_bounds"); 160 dynamic_loop_bounds_arg_ = &*arg_iter; 161 } 162 (++arg_iter)->setName("prof_counters"); 163 profile_counters_arg_ = &*arg_iter; 164 165 // We know a-priori that the function arguments are guaranteed to point to 166 // disjoint objects. 167 llvm::Argument* retval = result_arg(); 168 for (llvm::Argument& argument : function_->args()) { 169 // However, the return buffer aliases the temporaries and thus cannot be 170 // marked noalias. 171 if (&argument == retval) { 172 continue; 173 } 174 function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); 175 } 176 177 ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( 178 /*Context=*/llvm_module_->getContext(), 179 /*Name=*/"entry", 180 /*Parent=*/function_)); 181 } 182 183 llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { 184 CHECK_GT(num_dynamic_loop_bounds_, 0); 185 CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); 186 string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); 187 return ir_builder_->CreateLoad( 188 ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), 189 ir_builder_->getInt64(offset), AsStringRef(name))); 190 } 191 192 // Emits code to allocate an array of parameter address pointers, and store 193 // each address from 'parameter_addresses'. 194 // Returns an array of compute function call arguments (including parameter 195 // address buffer). 196 std::vector<llvm::Value*> GetArrayFunctionCallArguments( 197 tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, 198 llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, 199 llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, 200 llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { 201 llvm::Value* parameter_addresses_buffer = 202 llvm_ir::EmitAllocaAtFunctionEntryWithCount( 203 ir_builder->getInt8PtrTy(), 204 ir_builder->getInt32(parameter_addresses.size()), 205 tensorflow::strings::StrCat(name, "_parameter_addresses"), 206 ir_builder); 207 for (size_t i = 0; i < parameter_addresses.size(); ++i) { 208 llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( 209 parameter_addresses[i], ir_builder->getInt8PtrTy(), 210 AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, 211 "_address_as_i8ptr"))); 212 llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP( 213 parameter_addresses_buffer, {ir_builder->getInt64(i)}); 214 ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); 215 } 216 217 const auto to_int8_ptr = [=](llvm::Value* ptr) { 218 return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); 219 }; 220 std::vector<llvm::Value*> arguments{ 221 to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), 222 parameter_addresses_buffer, temp_buffers_arg}; 223 if (profile_counters_arg != nullptr) { 224 arguments.push_back(profile_counters_arg); 225 } 226 return arguments; 227 } 228 229 // Emits a call to a runtime fork/join function which dispatches parallel 230 // calls to 'parallel_function' (and joins threads before returning). 231 Status EmitCallToParallelForkJoin( 232 const std::vector<llvm::Value*>& arguments, const Shape& shape, 233 const std::vector<int64>& dimension_partition_counts, 234 llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, 235 const string& name) { 236 llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); 237 238 // Build ParallelForkJoin function type. 239 std::vector<llvm::Type*> compute_function_params = 240 GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); 241 // Number of parallel compute functions. 242 compute_function_params.push_back(ir_builder->getInt32Ty()); 243 // Array of partitions. There is an array element for each 244 // partition x partition_dim x 2 (for dimension start and limit). 245 compute_function_params.push_back( 246 llvm::Type::getInt64PtrTy(module->getContext())); 247 // Number of partitioned most-major dimensions in 'shape'. 248 compute_function_params.push_back(ir_builder->getInt32Ty()); 249 // Function pointer for compute function to be dispatched in parallel. 250 compute_function_params.push_back( 251 llvm::Type::getInt8PtrTy(module->getContext())); 252 253 llvm::FunctionType* fork_join_type = llvm::FunctionType::get( 254 /*Result=*/llvm::Type::getVoidTy(module->getContext()), 255 /*Params=*/compute_function_params, 256 /*isVarArg=*/false); 257 258 llvm::Function* fork_join_func = 259 llvm::cast<llvm::Function>(module->getOrInsertFunction( 260 runtime::kParallelForkJoinSymbolName, fork_join_type)); 261 fork_join_func->setCallingConv(llvm::CallingConv::C); 262 fork_join_func->setDoesNotThrow(); 263 264 // Add common compute function arguments. 265 std::vector<llvm::Value*> fork_join_arguments(arguments); 266 267 // Create ShapePartitionIterator to generate all partitions of 'shape'. 268 ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); 269 const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); 270 // Add argument specifying the number of parallel partitions. 271 fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); 272 273 // The number of partitioned most-major dimensions in 'shape'. 274 const int32 num_partitioned_dims = dimension_partition_counts.size(); 275 // A dimension partition consists of two elements: [start_index, limit_index). 276 const int32 dim_partition_size = 2; 277 // Calculate array partition stride. 278 const int32 array_partition_stride = 279 num_partitioned_dims * dim_partition_size; 280 // Calculate the total number of elements in the partition array. 281 const int32 partition_array_size = 282 dim_partition_size * num_partitioned_dims * num_partitions; 283 284 // Store dimension partition values as llvm constants in 'partitions'. 285 // See comments in runtime_fork_join.cc for array layout description. 286 std::vector<llvm::Constant*> partitions(partition_array_size); 287 for (int32 i = 0; i < num_partitions; ++i) { 288 std::vector<std::pair<int64, int64>> dim_partitions = 289 partition_iterator.GetPartition(i); 290 CHECK_EQ(num_partitioned_dims, dim_partitions.size()); 291 const int32 partition_index = i * array_partition_stride; 292 for (int32 j = 0; j < num_partitioned_dims; ++j) { 293 const std::pair<int64, int64>& dim_partition = dim_partitions[j]; 294 const int32 index = partition_index + j * dim_partition_size; 295 // Store partition [dim_start, dim_limit) intervals for each dimension. 296 partitions[index] = ir_builder->getInt64(dim_partition.first); 297 partitions[index + 1] = 298 ir_builder->getInt64(dim_partition.first + dim_partition.second); 299 } 300 } 301 302 // Create global variable out of dimension partitions in 'partitions'. 303 llvm::ArrayType* partitions_array_type = 304 llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); 305 llvm::Constant* partitions_array = 306 llvm::ConstantArray::get(partitions_array_type, partitions); 307 llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( 308 /*M=*/*module, 309 /*Ty=*/partitions_array_type, 310 /*isConstant=*/true, 311 /*Linkage=*/llvm::GlobalValue::PrivateLinkage, 312 /*Initializer=*/partitions_array, 313 /*Name=*/ 314 AsStringRef( 315 tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); 316 317 // Add argument specifying parallel dimension partitions. 318 fork_join_arguments.push_back(ir_builder->CreateBitCast( 319 global_partitions_array, 320 llvm::Type::getInt64PtrTy(module->getContext()))); 321 // Add argument specifying the number of partitioned most-major dimensions. 322 fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); 323 // Add argument for parallel compute function pointer. 324 fork_join_arguments.push_back( 325 ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); 326 // Emit call to parallel fork/join. 327 ir_builder->CreateCall(fork_join_func, fork_join_arguments); 328 329 return Status::OK(); 330 } 331 332 } // namespace cpu 333 } // namespace xla 334