Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <memory>
     17 #include <vector>
     18 
     19 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
     20 
     21 #include "llvm/IR/BasicBlock.h"
     22 #include "llvm/IR/Function.h"
     23 #include "llvm/IR/IRBuilder.h"
     24 #include "llvm/IR/Instructions.h"
     25 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
     26 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
     27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     31 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     32 #include "tensorflow/core/lib/core/status.h"
     33 
     34 namespace xla {
     35 namespace gpu {
     36 
     37 IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
     38                                  const HloComputation& nested_computation,
     39                                  IrEmitterContext* ir_emitter_context)
     40     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
     41   std::vector<const HloInstruction*> io_hlos;
     42   emitted_function_ =
     43       EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
     44 }
     45 
     46 llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
     47     const HloComputation& nested_computation,
     48     std::vector<const HloInstruction*>* io_hlos) {
     49   std::vector<llvm::Type*> argument_types;
     50   std::vector<int64> argument_dereferenceable_bytes;
     51   for (const HloInstruction* param :
     52        nested_computation.parameter_instructions()) {
     53     io_hlos->push_back(param);
     54     const Shape& param_shape = param->shape();
     55     argument_types.push_back(
     56         llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
     57     int64 param_size =
     58         llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
     59     argument_dereferenceable_bytes.push_back(param_size);
     60   }
     61   {
     62     const HloInstruction* root = nested_computation.root_instruction();
     63     io_hlos->push_back(root);
     64     const Shape& root_shape = root->shape();
     65     argument_types.push_back(
     66         llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
     67     int64 root_size = llvm_ir::ByteSizeOf(
     68         root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
     69     argument_dereferenceable_bytes.push_back(root_size);
     70   }
     71   // The base pointer of the memory block for all pre-allocated temp buffers.
     72   argument_types.push_back(ir_builder_.getInt8PtrTy());
     73 
     74   llvm::FunctionType* function_type =
     75       llvm::FunctionType::get(ir_builder_.getVoidTy(), argument_types, false);
     76   llvm::Function* function = llvm::Function::Create(
     77       function_type,                       // The function type.
     78       llvm::GlobalValue::InternalLinkage,  // The linkage type.
     79       llvm_ir::AsStringRef(ir_emitter_context_->name_uniquer()->GetUniqueName(
     80           llvm_ir::SanitizeFunctionName(
     81               nested_computation.name()))),  // The name of the function.
     82       ir_emitter_context_->llvm_module());   // The parent LLVM module.
     83   for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
     84        ++arg_no) {
     85     int64 arg_size = argument_dereferenceable_bytes[arg_no];
     86     if (arg_size > 0) {
     87       function->addDereferenceableAttr(arg_no + 1, arg_size);
     88     }
     89   }
     90 
     91   // TODO(b/65380986): Investigate if adding fast math flags for generated
     92   // kernels makes sense.
     93 
     94   llvm::BasicBlock* entry_bb =
     95       llvm::BasicBlock::Create(function->getContext(), "entry", function);
     96   // Emit a "return void" at entry_bb's end, and sets the insert point before
     97   // that return instruction.
     98   ir_builder_.SetInsertPoint(
     99       llvm::ReturnInst::Create(function->getContext(), entry_bb));
    100 
    101   std::vector<const HloInstruction*> non_io_hlos;
    102   for (const auto* hlo : nested_computation.instructions()) {
    103     if (hlo->opcode() != HloOpcode::kParameter &&
    104         hlo != nested_computation.root_instruction()) {
    105       non_io_hlos.push_back(hlo);
    106     }
    107   }
    108   bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
    109   return function;
    110 }
    111 
    112 Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
    113   return Status::OK();
    114 }
    115 
    116 Status IrEmitterNested::EmitTargetElementLoop(
    117     const HloInstruction& hlo,
    118     const llvm_ir::ElementGenerator& element_generator) {
    119   return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo),
    120                               &ir_builder_)
    121       .EmitLoop();
    122 }
    123 
    124 }  // namespace gpu
    125 }  // namespace xla
    126