/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

#include <functional>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Adds the inner comparison loop body where we compare elements.
Status EmitCompareLoopBody(
    int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index,
    int64 xor_mask, llvm::Type* index_type,
    std::function<llvm::Value*(int64 operand, llvm::Value* index)>
        element_address,
    std::function<void(int64 operand, llvm::Value* index, llvm::Value* value)>
        write_element,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
  auto index_typed_constant = [&](int64 value) {
    return llvm::ConstantInt::get(index_type, value);
  };
  // The 'xor_mask' determines which elements are compared against each other.
  // Index 'current_keys_index' will be compared with 'current_keys_index' xor
  // 'xor_mask'. This means that we will always compare a block of consecutive
  // elements against elements from the adjacent block of the same size. When
  // 'xor_mask' is a power of 2, it immediately identifies the size of such a
  // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
  // that case, we essentially flip the last 'k' bits when computing the
  // position of the element to compare to, so the block size is 2^(k - 1).
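  // For example, with 'xor_mask' = 4 the compared pairs are (0, 4), (1, 5),
  // (2, 6), (3, 7), (8, 12), ...; with 'xor_mask' = 3 they are (0, 3), (1, 2),
  // (4, 7), (5, 6), ..., i.e. blocks of size 2 compared against the adjacent
  // block in reversed order.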
  int64 block_size = xor_mask;
  // Check if it is a value 2^k - 1.
  if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
    block_size = (xor_mask + 1) / 2;
  }
  auto current_keys_index = element_pair_index;
  if (block_size == 1) {
    // If the block size is 1, we take every second element and compare it to
    // the next one.
    current_keys_index =
        b->CreateMul(current_keys_index, index_typed_constant(2));
  } else if (block_size * 2 < iteration_bound) {
    // current_keys_index iterates through the 'left' elements of the element
    // pairs to be compared. We first need to compute the comparison block to
    // which the element belongs. The block id of that block is index /
    // block_size.
    auto block_id =
        b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
    // The index of the 'left' element within its block is simply the remainder
    // when dividing by 'block_size'.
    auto index_within_block =
        b->CreateURem(current_keys_index, index_typed_constant(block_size));
    // The first element of the 'left' block of elements that is compared
    // against elements from the adjacent 'right' block of elements is
    // 'block_id' * (2 * 'block_size').
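    // For example, with 'block_size' = 2 (i.e. 'xor_mask' = 3) and
    // 'element_pair_index' = 5: 'block_id' = 2, 'index_within_block' = 1,
    // 'first_element_in_block' = 8, so the element at index 9 is compared
    // with the one at index 9 xor 3 = 10.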
    auto first_element_in_block =
        b->CreateMul(block_id, index_typed_constant(2 * block_size));
    current_keys_index =
        b->CreateAdd(first_element_in_block, index_within_block);
  }
  auto compare_keys_index =
      b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
  // current_keys_index < compare_keys_index
  llvm::Value* is_smaller_index =
      b->CreateICmpSLT(current_keys_index, compare_keys_index);
  // compare_keys_index < iteration_bound
  llvm::Value* index_is_inbounds = b->CreateICmpSLT(
      compare_keys_index, index_typed_constant(iteration_bound));
  llvm::Value* do_comparison =
      needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
                          : b->getInt1(true);

  // if (is_smaller_index && index_is_inbounds)
  KernelSupportLibrary ksl(b);
  return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
    std::vector<llvm::Value*> values_to_compare;
    for (int i = 0; i < num_values; ++i) {
      values_to_compare.push_back(element_address(i, compare_keys_index));
      values_to_compare.push_back(element_address(i, current_keys_index));
    }
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer",
        b);
    TF_RETURN_IF_ERROR(
        emit_compare_callback(values_to_compare, compare_return_buffer));
    llvm::Value* result = b->CreateLoad(compare_return_buffer);

    // Check if the 'compare' function returns true.
    llvm::Value* is_smaller_than =
        b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
                        "boolean_predicate");
    ksl.If("is_smaller_than", is_smaller_than, [&]() {
      for (int64 i = 0; i < num_values; ++i) {
        // Swap the values.
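        // values_to_compare[i * 2] holds the address of the element at
        // 'compare_keys_index' and values_to_compare[i * 2 + 1] the one at
        // 'current_keys_index'; storing them back crosswise swaps the pair.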
        auto value1 = b->CreateLoad(values_to_compare[i * 2]);
        auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]);
        write_element(i, current_keys_index, value1);
        write_element(i, compare_keys_index, value2);
      }
    });
    return Status::OK();
  });
}

Status EmitTiledCompareLoop(
    const IrArray::Index& tiled_keys_index, int64 dimension_to_sort,
    int64 dimension_to_sort_bound, absl::Span<const int64> xor_masks,
    const std::vector<IrArray>& params,
    const std::vector<llvm::Value*>& param_shmem_buffers, int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b) {
  KernelSupportLibrary ksl(b);
  llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b);
  llvm_ir::AddRangeMetadata(0, tile_size / 2,
                            llvm::cast<llvm::Instruction>(thread_id));
  thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
                               /*isSigned=*/true, "thread.id.x");

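  // Thread ids along x lie in [0, tile_size / 2) (see the range metadata
  // above), and thread 't' always uses the shared memory slots 2 * 't' and
  // 2 * 't' + 1. 'copy_loop_body' performs the read or write for this pair of
  // slots, bounds-checking each of the two element positions against
  // 'dimension_to_sort_bound'.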
  auto copy_loop_body =
      [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
              read_or_write) {
        auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
        auto current_keys_index =
            b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
        // We want to copy two adjacent elements. We first check whether the
        // first index position is within bounds.
        ksl.If(
            "smaller_keys_index",
            b->CreateICmpSLT(current_keys_index,
                             tiled_keys_index.GetConstantWithIndexType(
                                 dimension_to_sort_bound)),
            [&]() {
              auto cache_index = b->CreateShl(thread_id, value_one);
              read_or_write(cache_index, current_keys_index);
              // Increment to go to the next index position.
              current_keys_index = b->CreateAdd(current_keys_index, value_one);
              // Here we check whether the next index position is within bounds.
              ksl.If("inner_smaller_keys_index",
                     b->CreateICmpSLT(current_keys_index,
                                      tiled_keys_index.GetConstantWithIndexType(
                                          dimension_to_sort_bound)),
                     [&]() {
                       cache_index = b->CreateAdd(cache_index, value_one);
                       read_or_write(cache_index, current_keys_index);
                     });
            });
      };

  // Copy operand tiles from the operand buffers to shared memory.
  std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = params[i].EmitReadArrayElement(keys_index, b);
      b->CreateStore(value,
                     b->CreateGEP(param_shmem_buffers[i],
                                  {tiled_keys_index.GetConstantWithIndexType(0),
                                   cache_index}));
    });
  }
  // Wait until all reads have happened.
  llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b);

  // Now emit the bodies of the comparison loops.
  auto element_address = [&](int64 operand, llvm::Value* index) {
    auto shared_memory_address =
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index});
    auto ptr_type = shared_memory_address->getType();
    // We need a generic pointer with address space 0 instead of a pointer to
    // shared memory (address space 3) so that we can pass it to the comparison
    // computation.
    return b->CreateAddrSpaceCast(
        shared_memory_address,
        llvm::PointerType::get(ptr_type->getPointerElementType(),
                               /*AddressSpace=*/0));
  };
  auto write_element = [&](int64 operand, llvm::Value* index,
                           llvm::Value* value) {
    b->CreateStore(
        value,
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index}));
  };
  for (int64 xor_mask : xor_masks) {
    // The index of the element pair to be compared within the tile stored in
    // shared memory. We order the element pairs by the element with the
    // smaller index.
    auto element_pair_index = thread_id;
    // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we
    // don't need any bounds checks.
    if (dimension_to_sort_bound % tile_size) {
      // Otherwise we need a bounds check for the last tile. The last tile has
      // size 'dimension_to_sort_bound' % 'tile_size'.
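      // For example, with 'dimension_to_sort_bound' = 100 and 'tile_size' =
      // 64, the last tile covers elements [64, 100), so the compare loop body
      // for that tile runs with an iteration bound of 100 % 64 = 36 and with
      // bounds checks enabled.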
      TF_RETURN_IF_ERROR(ksl.IfWithStatus(
          "is_last_tile",
          b->CreateICmpUGE(
              b->CreateMul(tiled_keys_index[dimension_to_sort],
                           tiled_keys_index.GetConstantWithIndexType(2)),
              tiled_keys_index.GetConstantWithIndexType(
                  RoundDownToNearest(dimension_to_sort_bound, tile_size))),
          [&]() {
            return EmitCompareLoopBody(
                dimension_to_sort_bound % tile_size, params.size(),
                element_pair_index, xor_mask, tiled_keys_index.GetType(),
                element_address, write_element, emit_compare_callback, b);
          },
          [&]() {
            return EmitCompareLoopBody(
                tile_size, params.size(), element_pair_index, xor_mask,
                tiled_keys_index.GetType(), element_address, write_element,
                emit_compare_callback, b,
                /*needs_bounds_checks=*/false);
          }));
    } else {
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          tile_size, params.size(), element_pair_index, xor_mask,
          tiled_keys_index.GetType(), element_address, write_element,
          emit_compare_callback, b,
          /*needs_bounds_checks=*/false));
    }
    // Wait until all comparisons have happened.
    llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b);
  }

  // Copy the operand tiles back from shared memory to the operand buffers.
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = b->CreateLoad(b->CreateGEP(
          param_shmem_buffers[i],
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
      params[i].EmitWriteArrayElement(keys_index, value, b);
    });
  }
  // We should normally synchronize here to make sure all writes have happened.
  // However, the very next thing each thread does is read 2 elements from the
  // operand buffer and write them into the same locations in shared memory
  // from which it previously copied them to the operand buffer, and we
  // synchronize after that has happened. We can be sure that a thread always
  // writes to the same locations in shared memory because we have exactly
  // tile_size / 2 many threads, and the linear index calculated by
  // ParallelLoopEmitter uses
  // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  return Status::OK();
}
}  // namespace

Status EmitSortInPlace(
    int64 dimension_to_sort, const std::vector<IrArray>& values_arrays,
    absl::string_view name, absl::Span<const int64> xor_masks,
    llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
    int64 num_iterations_in_sort_dim, const int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback) {
  // Iterate through the keys shape in physical order, but skip the dimension
  // to sort and make it the innermost loop, which is where the comparisons
  // happen. In the dimension to sort, if we use tiling, we iterate through it
  // in tiles of 64 elements each, so we use another loop that happens within
  // one thread to process this tile's worth of data (thereby combining several
  // comparison stages of the bitonic sort algorithm, because they all happen
  // within those 64 elements and are therefore independent of the other
  // comparisons).

  const Shape& keys_shape = values_arrays[0].GetShape();
  int64 rank = keys_shape.rank();
  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::vector<int64> dimensions_in_iteration_order(rank);
  std::vector<int64> iteration_order_to_logical_order(rank);
  int64 dim = 0;
  for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) {
    if (dimension != dimension_to_sort) {
      dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
      iteration_order_to_logical_order[dim++] = dimension;
    }
  }
  dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
  iteration_order_to_logical_order[dim] = dimension_to_sort;

  Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                               dimensions_in_iteration_order);

  // Allocate shared memory for the tiled compare loop.
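  // A shared memory tile per operand is only needed on the tiled path, i.e.
  // when more than one xor mask is combined into a single kernel (see the
  // xor_masks.size() > 1 checks below and in the loop body emitter).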
  std::vector<llvm::Value*> param_shmem_buffers(values_arrays.size(), nullptr);
  if (xor_masks.size() > 1) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    for (int64 i = 0; i < values_arrays.size(); ++i) {
      llvm::Type* tile_type = llvm::ArrayType::get(
          llvm_ir::PrimitiveTypeToIrType(
              values_arrays[i].GetShape().element_type(), module),
          tile_size);
      param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
          module, tile_type, absl::StrCat(name, "_tile_param_", i));
    }
  }

  auto compare_loop_body_emitter =
      [&](const IrArray::Index& tiles_index) -> Status {
    // Naive C++ code for the inner compare loop:
    //
    // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
    //   int64 j = i ^ xor_mask;
    //   /* emitted in EmitCompareLoopBody() */
    //   if (i < j && j < dimension_to_sort_bound) {
    //     int64 min_key = std::min(keys[i], keys[j]);
    //     keys[j] = std::max(keys[i], keys[j]);
    //     keys[i] = min_key;
    //   }
    // }
    //
    // This follows the algorithm described on Wikipedia:
    // https://en.wikipedia.org/wiki/Bitonic_sorter
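    //
    // One such pass is emitted per entry of 'xor_masks'. For reference, a full
    // bitonic sort of n elements (n being 'dimension_to_sort_bound' rounded up
    // to a power of two) uses the mask sequence of the normalized bitonic
    // network, roughly:
    //
    // for (int64 k = 2; k <= n; k *= 2) {
    //   /* one pass with xor_mask = k - 1, ... */
    //   for (int64 j = k / 4; j > 0; j /= 2) {
    //     /* ... followed by passes with xor_mask = j */
    //   }
    // }
    //
    // The actual mask sequence used here is whatever the caller passes in via
    // 'xor_masks'.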
    std::vector<llvm::Value*> keys_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
    }
    if (xor_masks.size() > 1) {
      IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
                                tiles_index.GetType());
      TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
          keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
          values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
          b));
    } else {
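      // Untiled path: a single xor mask per kernel launch, with each element
      // pair read from and written back to the operand buffers directly.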
      auto element_address = [&](int64 operand, llvm::Value* index) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
      };
      auto write_element = [&](int64 operand, llvm::Value* index,
                               llvm::Value* value) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
      };
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
          xor_masks[0], tiles_index.GetType(), element_address, write_element,
          emit_compare_callback, b));
    }
    return Status::OK();
  };
  return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
                                  launch_dimensions, b)
      .EmitLoop(name);
}

}  // namespace llvm_ir
}  // namespace xla