Home | History | Annotate | Download | only in llvm_ir
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
     17 
     18 #include <numeric>
     19 #include <vector>
     20 
     21 #include "llvm/IR/Constants.h"
     22 #include "llvm/IR/Function.h"
     23 #include "llvm/IR/Instructions.h"
     24 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/types.h"
     27 #include "tensorflow/compiler/xla/xla_data.pb.h"
     28 #include "tensorflow/core/lib/strings/strcat.h"
     29 #include "tensorflow/core/lib/strings/stringprintf.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 
     32 namespace xla {
     33 namespace llvm_ir {
     34 
     35 ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
     36                  llvm::Value* start_index, llvm::Value* end_index,
     37                  llvm::Value* step, bool prevent_unrolling,
     38                  bool prevent_vectorization)
     39     : prefix_(prefix.ToString()),
     40       suffix_(suffix.ToString()),
     41       start_index_(start_index),
     42       end_index_(end_index),
     43       step_(step),
     44       insert_before_bb_(nullptr),
     45       prevent_unrolling_(prevent_unrolling),
     46       prevent_vectorization_(prevent_vectorization) {}
     47 
     48 /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
     49     tensorflow::StringPiece prefix, llvm::Value* start_index,
     50     llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
     51     bool prevent_unrolling, bool prevent_vectorization) {
     52   std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
     53                                             end_index, step, prevent_unrolling,
     54                                             prevent_vectorization));
     55   loop->Emit(ir_builder);
     56   return loop;
     57 }
     58 
     59 void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
     60   // The preheader block is the block the builder is currently emitting
     61   // code into.
     62   preheader_bb_ = ir_builder->GetInsertBlock();
     63 
     64   llvm::BasicBlock::iterator insert_point = ir_builder->GetInsertPoint();
     65   if (insert_point == preheader_bb_->end()) {
     66     // We're emitting the loop at the end of a basic block. Verify there is no
     67     // terminator (eg, branch) in the basic block.
     68     CHECK_EQ(nullptr, preheader_bb_->getTerminator());
     69 
     70     exit_bb_ = CreateLoopBB("loop_exit", ir_builder);
     71   } else {
     72     // We're emitting the loop into the middle of a basic block. splitBasicBlock
     73     // requires that this basic block be well-formed (have a terminator).
     74     CHECK_NE(nullptr, preheader_bb_->getTerminator());
     75 
     76     // Split the preheader to create an exit basic block. The exit basic block
     77     // will contain all instructions at or after insert_point.
     78     exit_bb_ = preheader_bb_->splitBasicBlock(
     79         insert_point, AsStringRef(GetQualifiedName("loop_exit")));
     80 
     81     // splitBasicBlock adds an unconditional branch between the split basic
     82     // blocks. Remove it. An unconditional branch will be added below from the
     83     // preheader to the header.
     84     preheader_bb_->getTerminator()->eraseFromParent();
     85   }
     86   insert_before_bb_ = exit_bb_;
     87 
     88   // Create remaining basic block which form the inside of the loop.
     89   header_bb_ = CreateLoopBB("loop_header", ir_builder);
     90   body_bb_ = CreateLoopBB("loop_body", ir_builder);
     91 
     92   // Function entry basic block.
     93   // Emit alloca for the induction variable. We do this at the entry to the
     94   // basic block to ensure the alloc only executes once per function (we could
     95   // be emitting a nested loop).
     96   llvm::Function* func = preheader_bb_->getParent();
     97   ir_builder->SetInsertPoint(&func->getEntryBlock(),
     98                              func->getEntryBlock().getFirstInsertionPt());
     99   llvm::Value* indvar_address =
    100       ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr,
    101                                AsStringRef(GetQualifiedName("invar_address")));
    102 
    103   // Preheader basic block.
    104   // Initialize induction variable starting index. Create branch to the header.
    105   ir_builder->SetInsertPoint(preheader_bb_);
    106   ir_builder->CreateStore(start_index_, indvar_address);
    107   // The preheader should not have a branch yet.
    108   CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
    109   ir_builder->CreateBr(header_bb_);
    110 
    111   // Header basic block.
    112   // Emit the loop conditional branch. Load and compare indvar with ending
    113   // index and jump to loop exit if equal. Jump to body otherwise.
    114   ir_builder->SetInsertPoint(header_bb_);
    115   indvar_ = ir_builder->CreateLoad(indvar_address,
    116                                    AsStringRef(GetQualifiedName("indvar")));
    117   llvm::Value* exit_cond = ir_builder->CreateICmpUGE(indvar_, end_index_);
    118   ir_builder->CreateCondBr(/*Cond=*/exit_cond,
    119                            /*True=*/exit_bb_, /*False=*/body_bb_);
    120 
    121   // Body basic block.
    122   // Increment indvar, store indvar, and jump to header.
    123   ir_builder->SetInsertPoint(body_bb_);
    124   llvm::Value* step = step_;
    125   llvm::Value* indvar = indvar_;
    126 
    127   llvm::Value* indvar_inc =
    128       ir_builder->CreateAdd(indvar, step, "invar.inc",
    129                             /*HasNUW=*/true, /*HasNSW=*/true);
    130   ir_builder->CreateStore(indvar_inc, indvar_address);
    131   llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_);
    132 
    133   std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder);
    134   if (!loop_metadata.empty()) {
    135     llvm::LLVMContext* ctx = &start_index_->getContext();
    136     auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
    137     loop_metadata.insert(loop_metadata.begin(), temp_node.get());
    138     auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
    139     loop_id->replaceOperandWith(0, loop_id);
    140     back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
    141   }
    142 
    143   // Re-point the IR builder to the loop exit block.
    144   ir_builder->SetInsertPoint(exit_bb_);
    145 }
    146 
    147 std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
    148     llvm::IRBuilder<>* ir_builder) {
    149   const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
    150   const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
    151   llvm::LLVMContext* ctx = &start_index_->getContext();
    152 
    153   std::vector<llvm::Metadata*> result;
    154   if (prevent_unrolling_) {
    155     result.push_back(llvm::MDNode::get(
    156         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
    157   }
    158 
    159   if (prevent_vectorization_) {
    160     result.push_back(llvm::MDNode::get(
    161         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
    162                llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
    163   }
    164 
    165   return result;
    166 }
    167 
    168 string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
    169   return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
    170 }
    171 
    172 llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
    173                                         llvm::IRBuilder<>* ir_builder) {
    174   return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name),
    175                           ir_builder);
    176 }
    177 
    178 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
    179                                               llvm::Value* start_index,
    180                                               llvm::Value* end_index,
    181                                               bool prevent_unrolling,
    182                                               bool prevent_vectorization) {
    183   return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
    184                  prevent_unrolling, prevent_vectorization);
    185 }
    186 
    187 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
    188                                               llvm::Value* start_index,
    189                                               llvm::Value* end_index,
    190                                               llvm::Value* stride,
    191                                               bool prevent_unrolling,
    192                                               bool prevent_vectorization) {
    193   if (inner_loop_body_bb_ != nullptr) {
    194     // Create this loop inside the previous one.
    195     ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
    196   }
    197   std::unique_ptr<ForLoop> loop(new ForLoop(
    198       /*prefix=*/name_, suffix, start_index, end_index, stride,
    199       prevent_unrolling, prevent_vectorization));
    200   loop->Emit(ir_builder_);
    201 
    202   if (outer_loop_preheader_bb_ == nullptr) {
    203     outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
    204   }
    205 
    206   if (outer_loop_exit_bb_ == nullptr) {
    207     outer_loop_exit_bb_ = loop->GetExitBasicBlock();
    208   }
    209 
    210   inner_loop_body_bb_ = loop->GetBodyBasicBlock();
    211 
    212   return loop;
    213 }
    214 
    215 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
    216                                               int64 end_index,
    217                                               tensorflow::StringPiece suffix,
    218                                               bool prevent_unrolling,
    219                                               bool prevent_vectorization) {
    220   CHECK_LE(start_index, end_index);
    221   return AddLoop(suffix, ir_builder_->getInt64(start_index),
    222                  ir_builder_->getInt64(end_index), prevent_unrolling,
    223                  prevent_vectorization);
    224 }
    225 
    226 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
    227                                               int64 end_index, int64 stride,
    228                                               tensorflow::StringPiece suffix,
    229                                               bool prevent_unrolling,
    230                                               bool prevent_vectorization) {
    231   CHECK_LE(start_index, end_index);
    232   return AddLoop(suffix, ir_builder_->getInt64(start_index),
    233                  ir_builder_->getInt64(end_index),
    234                  ir_builder_->getInt64(stride), prevent_unrolling,
    235                  prevent_vectorization);
    236 }
    237 
    238 IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
    239                                              tensorflow::StringPiece suffix) {
    240   std::vector<int64> dimensions(ShapeUtil::Rank(shape));
    241   std::iota(dimensions.begin(), dimensions.end(), 0);
    242   return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
    243 }
    244 
    245 IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
    246     const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
    247     tensorflow::StringPiece suffix) {
    248   llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr);
    249   for (int64 dimension : dimensions) {
    250     std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
    251         /*start_index=*/0,
    252         /*end_index=*/shape.dimensions(dimension),
    253         /*suffix=*/
    254         llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension)));
    255     index[dimension] = loop->GetIndVarValue();
    256   }
    257   return index;
    258 }
    259 
    260 }  // namespace llvm_ir
    261 }  // namespace xla
    262