Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
     17 
     18 #include <string>
     19 
     20 #include "tensorflow/compiler/xla/types.h"
     21 #include "tensorflow/compiler/xla/util.h"
     22 #include "tensorflow/core/lib/strings/strcat.h"
     23 #include "tensorflow/core/lib/strings/stringprintf.h"
     24 #include "tensorflow/core/platform/logging.h"
     25 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
     26 
     27 namespace se = ::perftools::gputools;
     28 
     29 namespace xla {
     30 namespace gpu {
     31 
// Constructs a scratch allocator that obtains device memory for cuFFT scratch
// space from `memory_allocator` on device `device_ordinal`. The allocator
// pointer is stored, not owned; presumably it must outlive this object —
// TODO(review): confirm the ownership contract with callers.
FftScratchAllocator::FftScratchAllocator(
    int device_ordinal, DeviceMemoryAllocator* memory_allocator)
    : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
     35 
     36 FftScratchAllocator::~FftScratchAllocator() {
     37   for (auto& allocated_buffer : allocated_buffers_) {
     38     if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
     39              .ok()) {
     40       // The program can still continue with failed deallocation.
     41       LOG(ERROR) << "Failed to deallocate the allocated buffer: "
     42                  << allocated_buffer.opaque();
     43     }
     44   }
     45 }
     46 
     47 int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
     48   constexpr int64 kFftScratchSize = 1LL << 32;  // 4GB by default.
     49   return kFftScratchSize;
     50 }
     51 
     52 se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
     53     se::Stream* stream, int64 byte_size) {
     54   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
     55   if (byte_size > GetMemoryLimitInBytes(stream)) {
     56     return se::port::Status(
     57         se::port::error::RESOURCE_EXHAUSTED,
     58         tensorflow::strings::Printf(
     59             "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
     60             byte_size, GetMemoryLimitInBytes(stream)));
     61   }
     62 
     63   auto status_or_memory =
     64       memory_allocator_->Allocate(device_ordinal_, byte_size,
     65                                   /*retry_on_failure=*/false);
     66   if (!status_or_memory.ok()) {
     67     return tensorflow::errors::ResourceExhausted(
     68         "Failed to allocate %lld bytes on device %d.", byte_size,
     69         device_ordinal_);
     70   }
     71   se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
     72   allocated_buffers_.push_back(allocated_buffer);
     73   total_allocated_bytes_ += byte_size;
     74   return se::DeviceMemory<uint8>(allocated_buffer);
     75 }
     76 
     77 namespace {
     78 
     79 se::fft::Type FftTypeToSeType(FftType type) {
     80   switch (type) {
     81     case FftType::FFT:
     82       return se::fft::Type::kC2CForward;
     83     case FftType::IFFT:
     84       return se::fft::Type::kC2CInverse;
     85     case FftType::IRFFT:
     86       return se::fft::Type::kC2R;
     87     case FftType::RFFT:
     88       return se::fft::Type::kR2C;
     89     default:
     90       LOG(FATAL) << "unsupported fft type";
     91   }
     92 }
     93 
     94 string FftTypeToString(se::fft::Type type) {
     95   switch (type) {
     96     case se::fft::Type::kC2CForward:
     97       return "FFT";
     98     case se::fft::Type::kC2CInverse:
     99       return "IFFT";
    100     case se::fft::Type::kC2R:
    101       return "IRFFT";
    102     case se::fft::Type::kR2C:
    103       return "RFFT";
    104     default:
    105       LOG(FATAL) << "unknown fft type";
    106   }
    107 }
    108 
    109 }  // namespace
    110 
// Constructs an FFT thunk for `hlo`. Buffer slices identify the input and
// output within the executable's buffer assignment; shapes carry the layouts
// used to derive the batched plan dimensions at execution time. The scale
// factor starts at 1.0 and is recomputed when the plan is first built (see
// ExecuteOnStream).
FftThunk::FftThunk(FftType fft_type,
                   tensorflow::gtl::ArraySlice<int64> fft_length,
                   const BufferAllocation::Slice& input_buffer,
                   const BufferAllocation::Slice& output_buffer,
                   const Shape& input_shape, const Shape& output_shape,
                   const HloInstruction* hlo)
    : Thunk(Kind::kFft, hlo),
      fft_type_(FftTypeToSeType(fft_type)),
      // Copy the transform lengths; the ArraySlice does not own its storage.
      fft_length_(fft_length.begin(), fft_length.end()),
      scale_factor_(1.0f),
      input_buffer_(input_buffer),
      output_buffer_(output_buffer),
      input_shape_(input_shape),
      output_shape_(output_shape) {}
    125 
// Runs the FFT on `stream`: lazily builds (then reuses) a batched FFT plan,
// dispatches the transform by type, and scales inverse transforms. Returns an
// internal error if any launch fails.
tensorflow::Status FftThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream) {
  VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
  VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(output_shape_);

  // Scratch is re-provisioned on every execution: a fresh allocator is built
  // here and handed to either plan creation or the plan update below.
  FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
                                        buffer_allocations.memory_allocator());

  if (fft_plan_ == nullptr) {
    // First execution: derive the batched-plan parameters from the shapes.
    const int64 fft_rank = fft_length_.size();
    CHECK_LE(fft_rank, 3);  // fixed-size arrays below hold at most 3 dims.
    // All dimensions in front of the trailing `fft_rank` transform dimensions
    // are folded into a single batch count.
    int batch_size = 1;
    for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) {
      batch_size *= input_shape_.dimensions(i);
    }
    uint64 fft_length[3];
    uint64 input_embed[3];
    const uint64 input_stride = 1;  // elements within a transform are dense.
    uint64 input_distance = 1;      // element count of one input batch.
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;     // element count of one output batch.

    // Fill per-dimension lengths/embeddings and accumulate batch distances
    // from the trailing `fft_rank` dimensions of each shape.
    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape_.dimensions_size() - fft_rank + i;
      fft_length[i] = static_cast<uint64>(fft_length_[i]);
      input_embed[i] = input_shape_.dimensions(dim_offset);
      input_distance *= input_shape_.dimensions(dim_offset);
      output_embed[i] = output_shape_.dimensions(dim_offset);
      output_distance *= output_shape_.dimensions(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    fft_plan_ =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_length, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            fft_type_, kInPlaceFft, batch_size, &scratch_allocator);
    // Inverse transforms are multiplied by this below — presumably because
    // the backend FFT is unnormalized, so 1/N (N = elements per output
    // batch) restores the expected scaling. NOTE(review): verify against the
    // StreamExecutor FFT contract.
    scale_factor_ = 1.0f / output_distance;
  } else {
    // Plan already exists; only refresh its scratch space for this run.
    stream->parent()->AsFft()->UpdatePlanWithScratchAllocator(
        stream, fft_plan_.get(), &scratch_allocator);
  }

  // Dispatch on the plan type; each case wraps the raw buffers in typed
  // DeviceMemory views matching the transform's input/output element types.
  bool launch_ok;
  switch (fft_type_) {
    case se::fft::Type::kC2CForward: {  // FFT: complex64 -> complex64
      se::DeviceMemory<complex64> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex64> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      break;
    }
    case se::fft::Type::kC2CInverse: {  // IFFT: complex64 -> complex64
      se::DeviceMemory<complex64> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex64> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      // Apply the 1/N scaling to the whole (complex) output in place.
      if (launch_ok) {
        launch_ok =
            stream
                ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
                               complex64(scale_factor_), &output_data, 1)
                .ok();
      }
      break;
    }
    case se::fft::Type::kR2C: {  // RFFT: float -> complex64
      se::DeviceMemory<float> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex64> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      break;
    }
    case se::fft::Type::kC2R: {  // IRFFT: complex64 -> float
      se::DeviceMemory<complex64> input_data(
          buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<float> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
      launch_ok =
          stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok();
      // Apply the 1/N scaling to the real output in place.
      if (launch_ok) {
        launch_ok = stream
                        ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
                                       scale_factor_, &output_data, 1)
                        .ok();
      }
      break;
    }
    default:
      LOG(FATAL) << "unsupported fft type";
  }
  if (launch_ok) {
    return tensorflow::Status::OK();
  }
  return InternalError("Unable to launch fft for thunk %p with type %s", this,
                       FftTypeToString(fft_type_).c_str());
}
    232 
    233 }  // namespace gpu
    234 }  // namespace xla
    235