/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/ir_function.h"

#include <iterator>

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

namespace {
using llvm_ir::AsStringRef;
}  // namespace

namespace cpu {

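// Returns the LLVM parameter types of the emitted compute function:
// (i8* retval, i8* run_options, i8** params, i8** temps,
//  [i64* dynamic_loop_bounds,] i64* prof_counters).
// The dynamic loop bounds parameter is only present when
// 'num_dynamic_loop_bounds' > 0 (it is elided for AOT compilation; see the
// signature comment in IrFunction::Initialize below).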
static std::vector<llvm::Type*> GetComputeFunctionParams(
    llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) {
  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
  llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
  llvm::Type* i64_ptr_type =
      llvm::Type::getInt64PtrTy(llvm_module->getContext());
  std::vector<llvm::Type*> compute_function_params(
      {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
  if (num_dynamic_loop_bounds > 0) {
    compute_function_params.push_back(i64_ptr_type);
  }
  compute_function_params.push_back(i64_ptr_type);
  return compute_function_params;
}

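// Creates the compute function in 'llvm_module' and positions 'ir_builder' at
// the function's entry basic block; the destructor emits the terminating
// 'ret void' once the caller has finished emitting the function body.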
IrFunction::IrFunction(const string& function_name,
                       llvm::Function::LinkageTypes linkage,
                       const bool optimize_for_size_requested,
                       const bool enable_fast_math, llvm::Module* llvm_module,
                       llvm::IRBuilder<>* ir_builder,
                       int64 num_dynamic_loop_bounds)
    : ir_builder_(ir_builder),
      llvm_module_(llvm_module),
      caller_insert_point_guard_(*ir_builder),
      num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
  Initialize(function_name, linkage, optimize_for_size_requested,
             enable_fast_math);
}

IrFunction::~IrFunction() {
  // Emit function return value.
  ir_builder_->CreateRetVoid();
}

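// Returns the dynamic loop bounds passed to the compute function. The i64
// array referenced by 'dynamic_loop_bounds_arg_' stores, for dynamic loop
// dimension 'i', its start value at index 2*i and its limit at index 2*i + 1.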
DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
  DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
  for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
    dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
    dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
  }
  return dynamic_loop_bounds;
}

void IrFunction::Initialize(const string& function_name,
                            llvm::Function::LinkageTypes linkage,
                            const bool optimize_for_size_requested,
                            const bool enable_fast_math) {
  // The function signature is:
  //   void function(i8* retval, i8* run_options, i8** params, i8** temps,
  //                 i64* dynamic_loop_bounds, i64* prof_counters)
  //
  // retval: points to the returned value.
  // params: address of an array with pointers to parameters.
  // temps: address of an array with pointers to temporary buffers.
  //
  // Therefore, the generated function's signature (FunctionType) is statically
  // determined - parameter unpacking is done in code generated into the
  // function, rather than by a prologue dictated by the platform ABI.
  //
  //                      /--------------\
  //   retval ----------> | return value |
  //                      \--------------/
  //
  //                      /-------------------------------\
  //   run_options -----> | xla::ExecutableRunOptions     |
  //                      \-------------------------------/
  //
  //                     /---------------------------------------------\
  //   params -------->  |  param 0  |  param 1  | ..... |  param N-1  |
  //                     |   addr    |   addr    |       |   addr      |
  //                     \---------------------------------------------/
  //                          |           |                   |
  //                          |           |                   |
  //                          V           V                   V
  //                     /---------\  /---------\         /-----------\
  //                     | param 0 |  | param 1 |         | param N-1 |
  //                     \---------/  \---------/         \-----------/
  //
  //                     /---------------------------------------------\
  //   temps --------->  |  temp  0  |  temp  1  | ..... |  temp  N-1  |
  //                     |   addr    |   addr    |       |   addr      |
  //                     \---------------------------------------------/
  //                          |           |                   |
  //                          |           |                   |
  //                          V           V                   V
  //                     /---------\  /---------\         /-----------\
  //                     | temp  0 |  | temp  1 |         | temp  N-1 |
  //                     \---------/  \---------/         \-----------/
  //
  //                        /--------------------------------------------\
  // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
  //  (elided for aot)      \--------------------------------------------/
  //
  //                     /---------------------------------------------\
  //   prof counters ->  | counter 0 | counter 1 | ..... | counter N-1 |
  //                     \---------------------------------------------/

  // Even though the type of params and temps is void** in the host's view, in
  // LLVM IR void* is represented by i8*, so params and temps are passed as
  // i8**. It's up to the generated code to use GEPs to unravel the indirection
  // layers.
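  // For illustration only (informal IR, not emitted here verbatim): generated
  // code loads the address of parameter 1 from 'params' roughly as
  //   %slot = getelementptr inbounds i8*, i8** %params, i64 1
  //   %p1_addr = load i8*, i8** %slot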
  llvm::FunctionType* function_type = llvm::FunctionType::get(
      /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
      /*Params=*/
      GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
      /*isVarArg=*/false);

  // Functions with local linkage get an inlining bonus. Because we know
  // a priori that embedded (non-entry) functions will not have their names
  // resolved, give them local linkage.
  function_ =
      llvm_ir::CreateFunction(function_type, linkage,
                              /*enable_fast_math=*/enable_fast_math,
                              /*optimize_for_size=*/optimize_for_size_requested,
                              function_name, llvm_module_);

  // Set meaningful names for the function's arguments: useful for debugging.
  llvm::Function::arg_iterator arg_iter = function_->arg_begin();
  arg_iter->setName("retval");
  result_arg_ = &*arg_iter;
  (++arg_iter)->setName("run_options");
  exec_run_options_arg_ = &*arg_iter;
  (++arg_iter)->setName("params");
  parameters_arg_ = &*arg_iter;
  (++arg_iter)->setName("temps");
  temp_buffers_arg_ = &*arg_iter;
  if (num_dynamic_loop_bounds_ > 0) {
    (++arg_iter)->setName("dynamic_loop_bounds");
    dynamic_loop_bounds_arg_ = &*arg_iter;
  }
  (++arg_iter)->setName("prof_counters");
  profile_counters_arg_ = &*arg_iter;

  // We know a priori that the function arguments are guaranteed to point to
  // disjoint objects, so mark them noalias.
  llvm::Argument* retval = result_arg();
  for (llvm::Argument& argument : function_->args()) {
    // However, the return buffer aliases the temporaries and thus cannot be
    // marked noalias.
    if (&argument == retval) {
      continue;
    }
    // Attribute index 0 is the return value; argument attributes start at 1.
    function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias);
  }

  ir_builder_->SetInsertPoint(llvm::BasicBlock::Create(
      /*Context=*/llvm_module_->getContext(),
      /*Name=*/"entry",
      /*Parent=*/function_));
}

llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
  CHECK_GT(num_dynamic_loop_bounds_, 0);
  CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
  string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset);
  return ir_builder_->CreateLoad(
      ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
                             ir_builder_->getInt64(offset), AsStringRef(name)));
}

// Emits code to allocate an array of parameter address pointers and to store
// each address from 'parameter_addresses' into it.
// Returns the array of compute function call arguments (including the
// parameter address buffer).
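// For illustration only (informal IR, not a verbatim dump): for two
// parameters this emits roughly
//   %addrs = alloca i8*, i32 2
//   store i8* %param0_as_i8ptr, i8** %addrs
//   %slot1 = getelementptr inbounds i8*, i8** %addrs, i64 1
//   store i8* %param1_as_i8ptr, i8** %slot1
// and returns {retval, run_options, %addrs, temps[, prof_counters]}.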
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
    tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
    llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name,
    llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
    llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
  llvm::Value* parameter_addresses_buffer =
      llvm_ir::EmitAllocaAtFunctionEntryWithCount(
          ir_builder->getInt8PtrTy(),
          ir_builder->getInt32(parameter_addresses.size()),
          tensorflow::strings::StrCat(name, "_parameter_addresses"),
          ir_builder);
  for (size_t i = 0; i < parameter_addresses.size(); ++i) {
    llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast(
        parameter_addresses[i], ir_builder->getInt8PtrTy(),
        AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i,
                                                "_address_as_i8ptr")));
    llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP(
        parameter_addresses_buffer, {ir_builder->getInt64(i)});
    ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
  }

  const auto to_int8_ptr = [=](llvm::Value* ptr) {
    return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy());
  };
  std::vector<llvm::Value*> arguments{
      to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
      parameter_addresses_buffer, temp_buffers_arg};
  if (profile_counters_arg != nullptr) {
    arguments.push_back(profile_counters_arg);
  }
  return arguments;
}


// Emits a call to a runtime fork/join function which dispatches parallel
// calls to 'parallel_function' (and joins threads before returning).
Status EmitCallToParallelForkJoin(
    const std::vector<llvm::Value*>& arguments, const Shape& shape,
    const std::vector<int64>& dimension_partition_counts,
    llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function,
    const string& name) {
  llvm::Module* module = ir_builder->GetInsertBlock()->getModule();

  // Build ParallelForkJoin function type.
  std::vector<llvm::Type*> compute_function_params =
      GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
  // Number of parallel compute functions.
  compute_function_params.push_back(ir_builder->getInt32Ty());
  // Array of partitions: num_partitions x num_partitioned_dims x 2 elements
  // in total (a dimension start and limit for each partitioned dimension of
  // each partition).
  compute_function_params.push_back(
      llvm::Type::getInt64PtrTy(module->getContext()));
  // Number of partitioned most-major dimensions in 'shape'.
  compute_function_params.push_back(ir_builder->getInt32Ty());
  // Function pointer for compute function to be dispatched in parallel.
  compute_function_params.push_back(
      llvm::Type::getInt8PtrTy(module->getContext()));
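  // The resulting fork/join signature is therefore (informally):
  //   void(i8* retval, i8* run_options, i8** params, i8** temps,
  //        i64* prof_counters, i32 num_partitions, i64* partitions,
  //        i32 num_partitioned_dims, i8* parallel_function).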

  llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
      /*Result=*/llvm::Type::getVoidTy(module->getContext()),
      /*Params=*/compute_function_params,
      /*isVarArg=*/false);

  llvm::Function* fork_join_func =
      llvm::cast<llvm::Function>(module->getOrInsertFunction(
          runtime::kParallelForkJoinSymbolName, fork_join_type));
  fork_join_func->setCallingConv(llvm::CallingConv::C);
  fork_join_func->setDoesNotThrow();

  // Add common compute function arguments.
  std::vector<llvm::Value*> fork_join_arguments(arguments);

  // Create ShapePartitionIterator to generate all partitions of 'shape'.
  ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
  const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
  // Add argument specifying the number of parallel partitions.
  fork_join_arguments.push_back(ir_builder->getInt32(num_partitions));

  // The number of partitioned most-major dimensions in 'shape'.
  const int32 num_partitioned_dims = dimension_partition_counts.size();
  // A dimension partition consists of two elements: [start_index, limit_index).
  const int32 dim_partition_size = 2;
  // Calculate array partition stride.
  const int32 array_partition_stride =
      num_partitioned_dims * dim_partition_size;
  // Calculate the total number of elements in the partition array.
  const int32 partition_array_size =
      dim_partition_size * num_partitioned_dims * num_partitions;

  // Store dimension partition values as llvm constants in 'partitions'.
  // See comments in runtime_fork_join.cc for array layout description.
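  // For illustration only: if a single most-major dimension of size 8 is split
  // into two partitions covering [0, 4) and [4, 8), the flattened array built
  // below is {0, 4, 4, 8}.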
  std::vector<llvm::Constant*> partitions(partition_array_size);
  for (int32 i = 0; i < num_partitions; ++i) {
    std::vector<std::pair<int64, int64>> dim_partitions =
        partition_iterator.GetPartition(i);
    CHECK_EQ(num_partitioned_dims, dim_partitions.size());
    const int32 partition_index = i * array_partition_stride;
    for (int32 j = 0; j < num_partitioned_dims; ++j) {
      const std::pair<int64, int64>& dim_partition = dim_partitions[j];
      const int32 index = partition_index + j * dim_partition_size;
      // Store partition [dim_start, dim_limit) intervals for each dimension.
      partitions[index] = ir_builder->getInt64(dim_partition.first);
      partitions[index + 1] =
          ir_builder->getInt64(dim_partition.first + dim_partition.second);
    }
  }

  // Create global variable out of dimension partitions in 'partitions'.
  llvm::ArrayType* partitions_array_type =
      llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size);
  llvm::Constant* partitions_array =
      llvm::ConstantArray::get(partitions_array_type, partitions);
  llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
      /*M=*/*module,
      /*Ty=*/partitions_array_type,
      /*isConstant=*/true,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/partitions_array,
      /*Name=*/
      AsStringRef(
          tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));

  // Add argument specifying parallel dimension partitions.
  fork_join_arguments.push_back(ir_builder->CreateBitCast(
      global_partitions_array,
      llvm::Type::getInt64PtrTy(module->getContext())));
  // Add argument specifying the number of partitioned most-major dimensions.
  fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims));
  // Add argument for parallel compute function pointer.
  fork_join_arguments.push_back(
      ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy()));
  // Emit call to parallel fork/join.
  ir_builder->CreateCall(fork_join_func, fork_join_arguments);

  return Status::OK();
}

}  // namespace cpu
}  // namespace xla