Home | History | Annotate | Download | only in cuda
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/stream_executor/cuda/cuda_fft.h"
     17 
     18 #include <complex>
     19 
     20 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
     21 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
     22 #include "tensorflow/stream_executor/cuda/cuda_helpers.h"
     23 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
     24 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
     25 #include "tensorflow/stream_executor/device_memory.h"
     26 #include "tensorflow/stream_executor/lib/env.h"
     27 #include "tensorflow/stream_executor/lib/initialize.h"
     28 #include "tensorflow/stream_executor/lib/status.h"
     29 #include "tensorflow/stream_executor/platform/logging.h"
     30 #include "tensorflow/stream_executor/platform/port.h"
     31 #include "tensorflow/stream_executor/plugin_registry.h"
     32 #include "tensorflow/stream_executor/stream_executor_internal.h"
     33 
     34 namespace perftools {
     35 namespace gputools {
     36 namespace cuda {
     37 
     38 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuFftPlugin);
     39 
     40 namespace wrap {
     41 
// This macro wraps a global cuFFT identifier, given by __name, in a callable
// structure named WrapperShim__<name>. The shim's operator() takes the owning
// CUDAExecutor as its first argument, activates that executor's CUDA context
// for the duration of the call, and forwards the remaining arguments to the
// real ::__name symbol, returning its cufftResult.
// NOTE(review): an earlier comment here described thread-safe dynamic DSO
// loading; this shim calls ::__name directly (link-time binding) — confirm
// whether lazy DSO loading is still intended.
#define PERFTOOLS_GPUTOOLS_CUFFT_WRAP(__name)                    \
  struct WrapperShim__##__name {                                 \
    template <typename... Args>                                  \
    cufftResult operator()(CUDAExecutor *parent, Args... args) { \
      cuda::ScopedActivateExecutorContext sac{parent};           \
      return ::__name(args...);                                  \
    }                                                            \
  } __name;

// The list of cuFFT entry points used by this file; applying a macro to this
// list stamps it out once per routine.
#define CUFFT_ROUTINE_EACH(__macro)                                            \
  __macro(cufftDestroy) __macro(cufftSetStream) __macro(cufftPlan1d)           \
      __macro(cufftPlan2d) __macro(cufftPlan3d) __macro(cufftPlanMany)         \
          __macro(cufftExecD2Z) __macro(cufftExecZ2D) __macro(cufftExecC2C)    \
              __macro(cufftExecC2R) __macro(cufftExecZ2Z)                      \
                  __macro(cufftExecR2C) __macro(cufftCreate)                   \
                      __macro(cufftSetAutoAllocation)                          \
                          __macro(cufftSetWorkArea) __macro(cufftGetSize1d)    \
                              __macro(cufftMakePlan1d) __macro(cufftGetSize2d) \
                                  __macro(cufftMakePlan2d)                     \
                                      __macro(cufftGetSize3d)                  \
                                          __macro(cufftMakePlan3d)             \
                                              __macro(cufftGetSizeMany)        \
                                                  __macro(cufftMakePlanMany)

// Define one shim callable (wrap::cufftXxx) for every routine listed above.
CUFFT_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUFFT_WRAP)
     72 
     73 }  // namespace wrap
     74 
     75 namespace {
     76 
     77 // A helper function transforming gpu_fft arguments into cuFFT arguments.
     78 cufftType CUDAFftType(fft::Type type) {
     79   switch (type) {
     80     case fft::Type::kC2CForward:
     81     case fft::Type::kC2CInverse:
     82       return CUFFT_C2C;
     83     case fft::Type::kC2R:
     84       return CUFFT_C2R;
     85     case fft::Type::kR2C:
     86       return CUFFT_R2C;
     87     case fft::Type::kZ2ZForward:
     88     case fft::Type::kZ2ZInverse:
     89       return CUFFT_Z2Z;
     90     case fft::Type::kZ2D:
     91       return CUFFT_Z2D;
     92     case fft::Type::kD2Z:
     93       return CUFFT_D2Z;
     94     default:
     95       LOG(FATAL) << "Invalid value of fft::Type.";
     96   }
     97 }
     98 
     99 // Associates the given stream with the given cuFFT plan.
    100 bool SetStream(CUDAExecutor *parent, cufftHandle plan, Stream *stream) {
    101   auto ret = wrap::cufftSetStream(parent, plan, AsCUDAStreamValue(stream));
    102   if (ret != CUFFT_SUCCESS) {
    103     LOG(ERROR) << "failed to run cuFFT routine cufftSetStream: " << ret;
    104     return false;
    105   }
    106   return true;
    107 }
    108 
    109 }  // namespace
    110 
    111 port::Status CUDAFftPlan::Initialize(
    112     CUDAExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
    113     uint64 *input_embed, uint64 input_stride, uint64 input_distance,
    114     uint64 *output_embed, uint64 output_stride, uint64 output_distance,
    115     fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
    116   if (IsInitialized()) {
    117     LOG(FATAL) << "Try to repeatedly initialize.";
    118   }
    119   is_initialized_ = true;
    120   int elem_count_[3], input_embed_[3], output_embed_[3];
    121   for (int i = 0; i < rank; ++i) {
    122     elem_count_[i] = elem_count[i];
    123     if (input_embed) {
    124       input_embed_[i] = input_embed[i];
    125     }
    126     if (output_embed) {
    127       output_embed_[i] = output_embed[i];
    128     }
    129   }
    130   parent_ = parent;
    131   fft_type_ = type;
    132   if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
    133     cufftResult_t ret;
    134     if (scratch_allocator == nullptr) {
    135       switch (rank) {
    136         case 1:
    137           // cufftPlan1d
    138           ret = wrap::cufftPlan1d(parent, &plan_, elem_count_[0],
    139                                   CUDAFftType(type), 1 /* = batch */);
    140           if (ret != CUFFT_SUCCESS) {
    141             LOG(ERROR) << "failed to create cuFFT 1d plan:" << ret;
    142             return port::Status{port::error::INTERNAL,
    143                                 "Failed to create cuFFT 1d plan."};
    144           }
    145           return port::Status::OK();
    146         case 2:
    147           // cufftPlan2d
    148           ret = wrap::cufftPlan2d(parent, &plan_, elem_count_[0],
    149                                   elem_count_[1], CUDAFftType(type));
    150           if (ret != CUFFT_SUCCESS) {
    151             LOG(ERROR) << "failed to create cuFFT 2d plan:" << ret;
    152             return port::Status{port::error::INTERNAL,
    153                                 "Failed to create cuFFT 2d plan."};
    154           }
    155           return port::Status::OK();
    156         case 3:
    157           // cufftPlan3d
    158           ret =
    159               wrap::cufftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
    160                                 elem_count_[2], CUDAFftType(type));
    161           if (ret != CUFFT_SUCCESS) {
    162             LOG(ERROR) << "failed to create cuFFT 3d plan:" << ret;
    163             return port::Status{port::error::INTERNAL,
    164                                 "Failed to create cuFFT 3d plan."};
    165           }
    166           return port::Status::OK();
    167         default:
    168           LOG(ERROR) << "Invalid rank value for cufftPlan. "
    169                         "Requested 1, 2, or 3, given: "
    170                      << rank;
    171           return port::Status{port::error::INVALID_ARGUMENT,
    172                               "cufftPlan only takes rank 1, 2, or 3."};
    173       }
    174     } else {
    175       ret = wrap::cufftCreate(parent, &plan_);
    176       if (ret != CUFFT_SUCCESS) {
    177         LOG(ERROR) << "failed to create cuFFT plan:" << ret;
    178         return port::Status{port::error::INTERNAL,
    179                             "Failed to create cuFFT plan."};
    180       }
    181       ret = wrap::cufftSetAutoAllocation(parent, plan_, 0);
    182       if (ret != CUFFT_SUCCESS) {
    183         LOG(ERROR) << "failed to set auto allocation for cuFFT plan:" << ret;
    184         return port::Status{port::error::INTERNAL,
    185                             "Failed to set auto allocation for cuFFT plan."};
    186       }
    187       switch (rank) {
    188         case 1:
    189           ret = wrap::cufftMakePlan1d(parent, plan_, elem_count_[0],
    190                                       CUDAFftType(type), /*batch=*/1,
    191                                       &scratch_size_bytes_);
    192           if (ret != CUFFT_SUCCESS) {
    193             LOG(ERROR) << "failed to make cuFFT 1d plan:" << ret;
    194             return port::Status{port::error::INTERNAL,
    195                                 "Failed to make cuFFT 1d plan."};
    196           }
    197           break;
    198         case 2:
    199           ret = wrap::cufftMakePlan2d(parent, plan_, elem_count_[0],
    200                                       elem_count_[1], CUDAFftType(type),
    201                                       &scratch_size_bytes_);
    202           if (ret != CUFFT_SUCCESS) {
    203             LOG(ERROR) << "failed to make cuFFT 2d plan:" << ret;
    204             return port::Status{port::error::INTERNAL,
    205                                 "Failed to make cuFFT 2d plan."};
    206           }
    207           break;
    208         case 3:
    209           ret = wrap::cufftMakePlan3d(parent, plan_, elem_count_[0],
    210                                       elem_count_[1], elem_count_[2],
    211                                       CUDAFftType(type), &scratch_size_bytes_);
    212           if (ret != CUFFT_SUCCESS) {
    213             LOG(ERROR) << "failed to make cuFFT 3d plan:" << ret;
    214             return port::Status{port::error::INTERNAL,
    215                                 "Failed to make cuFFT 3d plan."};
    216           }
    217           break;
    218         default:
    219           LOG(ERROR) << "Invalid rank value for cufftPlan. "
    220                         "Requested 1, 2, or 3, given: "
    221                      << rank;
    222           return port::Status{port::error::INVALID_ARGUMENT,
    223                               "cufftPlan only takes rank 1, 2, or 3."};
    224       }
    225       return UpdateScratchAllocator(stream, scratch_allocator);
    226     }
    227   } else {
    228     // For either multiple batches or rank higher than 3, use cufftPlanMany().
    229     if (scratch_allocator == nullptr) {
    230       auto ret = wrap::cufftPlanMany(
    231           parent, &plan_, rank, elem_count_,
    232           input_embed ? input_embed_ : nullptr, input_stride, input_distance,
    233           output_embed ? output_embed_ : nullptr, output_stride,
    234           output_distance, CUDAFftType(type), batch_count);
    235       if (ret != CUFFT_SUCCESS) {
    236         LOG(ERROR) << "failed to create cuFFT batched plan:" << ret;
    237         return port::Status{port::error::INTERNAL,
    238                             "Failed to create cuFFT batched plan."};
    239       }
    240     } else {
    241       auto ret = wrap::cufftCreate(parent, &plan_);
    242       if (ret != CUFFT_SUCCESS) {
    243         LOG(ERROR) << "failed to create cuFFT batched plan:" << ret;
    244         return port::Status{port::error::INTERNAL,
    245                             "Failed to create cuFFT batched plan."};
    246       }
    247       ret = wrap::cufftSetAutoAllocation(parent, plan_, 0);
    248       if (ret != CUFFT_SUCCESS) {
    249         LOG(ERROR) << "failed to set auto allocation for cuFFT batched plan:"
    250                    << ret;
    251         return port::Status{
    252             port::error::INTERNAL,
    253             "Failed to set auto allocation for cuFFT batched plan."};
    254       }
    255       ret = wrap::cufftMakePlanMany(
    256           parent, plan_, rank, elem_count_,
    257           input_embed ? input_embed_ : nullptr, input_stride, input_distance,
    258           output_embed ? output_embed_ : nullptr, output_stride,
    259           output_distance, CUDAFftType(type), batch_count,
    260           &scratch_size_bytes_);
    261       if (ret != CUFFT_SUCCESS) {
    262         LOG(ERROR) << "failed to make cuFFT batched plan:" << ret;
    263         return port::Status{port::error::INTERNAL,
    264                             "Failed to make cuFFT batched plan."};
    265       }
    266       return UpdateScratchAllocator(stream, scratch_allocator);
    267     }
    268   }
    269   return port::Status::OK();
    270 }
    271 
    272 port::Status CUDAFftPlan::Initialize(CUDAExecutor *parent, Stream *stream,
    273                                      int rank, uint64 *elem_count,
    274                                      fft::Type type,
    275                                      ScratchAllocator *scratch_allocator) {
    276   return Initialize(parent_, stream, rank, elem_count,
    277                     /*input_embed=*/nullptr, /*input_stride=*/0,
    278                     /*input_distance=*/0,
    279                     /*output_embed=*/nullptr, /*output_stride=*/0,
    280                     /*output_distance=*/0, type, 1, scratch_allocator);
    281 }
    282 
    283 port::Status CUDAFftPlan::UpdateScratchAllocator(
    284     Stream *stream, ScratchAllocator *scratch_allocator) {
    285   if (scratch_size_bytes_ != 0) {
    286     auto allocated =
    287         scratch_allocator->AllocateBytes(stream, scratch_size_bytes_);
    288     if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
    289       LOG(ERROR) << "failed to allocate work area.";
    290       return allocated.status();
    291     }
    292   }
    293   // Connect work area with allocated space.
    294   cufftResult_t ret = wrap::cufftSetWorkArea(parent_, plan_, scratch_.opaque());
    295   if (ret != CUFFT_SUCCESS) {
    296     LOG(ERROR) << "failed to set work area for cuFFT plan:" << ret;
    297     return port::Status{port::error::INTERNAL,
    298                         "Failed to set work area for cuFFT plan."};
    299   }
    300   return port::Status::OK();
    301 }
    302 
// Destroys the underlying cuFFT plan handle. The cufftDestroy result is not
// checked, since a destructor has no way to report failure.
// NOTE(review): if Initialize() was never called, plan_ may be an unset
// handle here — confirm callers always initialize before destruction.
CUDAFftPlan::~CUDAFftPlan() { wrap::cufftDestroy(parent_, plan_); }
    304 
    305 int CUDAFftPlan::GetFftDirection() const {
    306   if (!IsInitialized()) {
    307     LOG(FATAL) << "Try to get fft direction before initialization.";
    308   } else {
    309     switch (fft_type_) {
    310       case fft::Type::kC2CForward:
    311       case fft::Type::kZ2ZForward:
    312       case fft::Type::kR2C:
    313       case fft::Type::kD2Z:
    314         return CUFFT_FORWARD;
    315       case fft::Type::kC2CInverse:
    316       case fft::Type::kZ2ZInverse:
    317       case fft::Type::kC2R:
    318       case fft::Type::kZ2D:
    319         return CUFFT_INVERSE;
    320       default:
    321         LOG(FATAL) << "Invalid value of fft::Type.";
    322     }
    323   }
    324 }
    325 
    326 std::unique_ptr<fft::Plan> CUDAFft::Create1dPlan(Stream *stream, uint64 num_x,
    327                                                  fft::Type type,
    328                                                  bool in_place_fft) {
    329   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    330   uint64 elem_count[1] = {num_x};
    331   port::Status status = fft_plan_ptr->Initialize(
    332       parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
    333   // TODO(yangzihao): In the future, send error msg back to TensorFlow
    334   // so it can fail gracefully,
    335   if (!status.ok()) {
    336     LOG(FATAL) << "failed to initialize cufft 1d plan: "
    337                << status.error_message();
    338   }
    339   return std::move(fft_plan_ptr);
    340 }
    341 
    342 std::unique_ptr<fft::Plan> CUDAFft::Create1dPlanWithScratchAllocator(
    343     Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
    344     ScratchAllocator *scratch_allocator) {
    345   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    346   uint64 elem_count[1] = {num_x};
    347   port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count,
    348                                                  type, scratch_allocator);
    349   if (!status.ok()) {
    350     LOG(FATAL)
    351         << "failed to initialize cufft 1d plan with customized allocator: "
    352         << status.error_message();
    353   }
    354   return std::move(fft_plan_ptr);
    355 }
    356 
    357 std::unique_ptr<fft::Plan> CUDAFft::Create2dPlan(Stream *stream, uint64 num_x,
    358                                                  uint64 num_y, fft::Type type,
    359                                                  bool in_place_fft) {
    360   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    361   uint64 elem_count[2] = {num_x, num_y};
    362   port::Status status = fft_plan_ptr->Initialize(
    363       parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
    364   if (!status.ok()) {
    365     LOG(FATAL) << "failed to initialize cufft 2d plan: "
    366                << status.error_message();
    367   }
    368   return std::move(fft_plan_ptr);
    369 }
    370 
    371 std::unique_ptr<fft::Plan> CUDAFft::Create2dPlanWithScratchAllocator(
    372     Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
    373     bool in_place_fft, ScratchAllocator *scratch_allocator) {
    374   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    375   uint64 elem_count[2] = {num_x, num_y};
    376   port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count,
    377                                                  type, scratch_allocator);
    378   if (!status.ok()) {
    379     LOG(FATAL)
    380         << "failed to initialize cufft 2d plan with customized allocator: "
    381         << status.error_message();
    382   }
    383   return std::move(fft_plan_ptr);
    384 }
    385 
    386 std::unique_ptr<fft::Plan> CUDAFft::Create3dPlan(Stream *stream, uint64 num_x,
    387                                                  uint64 num_y, uint64 num_z,
    388                                                  fft::Type type,
    389                                                  bool in_place_fft) {
    390   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    391   uint64 elem_count[3] = {num_x, num_y, num_z};
    392   port::Status status = fft_plan_ptr->Initialize(
    393       parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr);
    394   if (!status.ok()) {
    395     LOG(FATAL) << "failed to initialize cufft 3d plan: "
    396                << status.error_message();
    397   }
    398   return std::move(fft_plan_ptr);
    399 }
    400 
    401 std::unique_ptr<fft::Plan> CUDAFft::Create3dPlanWithScratchAllocator(
    402     Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
    403     bool in_place_fft, ScratchAllocator *scratch_allocator) {
    404   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    405   uint64 elem_count[3] = {num_x, num_y, num_z};
    406   port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count,
    407                                                  type, scratch_allocator);
    408   if (!status.ok()) {
    409     LOG(FATAL)
    410         << "failed to initialize cufft 3d plan with customized allocator: "
    411         << status.error_message();
    412   }
    413   return std::move(fft_plan_ptr);
    414 }
    415 
    416 std::unique_ptr<fft::Plan> CUDAFft::CreateBatchedPlan(
    417     Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
    418     uint64 input_stride, uint64 input_distance, uint64 *output_embed,
    419     uint64 output_stride, uint64 output_distance, fft::Type type,
    420     bool in_place_fft, int batch_count) {
    421   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    422   port::Status status = fft_plan_ptr->Initialize(
    423       parent_, stream, rank, elem_count, input_embed, input_stride,
    424       input_distance, output_embed, output_stride, output_distance, type,
    425       batch_count, /*scratch_allocator=*/nullptr);
    426   if (!status.ok()) {
    427     LOG(FATAL) << "failed to initialize batched cufft plan: "
    428                << status.error_message();
    429   }
    430 
    431   return std::move(fft_plan_ptr);
    432 }
    433 
    434 std::unique_ptr<fft::Plan> CUDAFft::CreateBatchedPlanWithScratchAllocator(
    435     Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
    436     uint64 input_stride, uint64 input_distance, uint64 *output_embed,
    437     uint64 output_stride, uint64 output_distance, fft::Type type,
    438     bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
    439   std::unique_ptr<CUDAFftPlan> fft_plan_ptr{new CUDAFftPlan()};
    440   port::Status status = fft_plan_ptr->Initialize(
    441       parent_, stream, rank, elem_count, input_embed, input_stride,
    442       input_distance, output_embed, output_stride, output_distance, type,
    443       batch_count, scratch_allocator);
    444   if (!status.ok()) {
    445     LOG(FATAL)
    446         << "failed to initialize batched cufft plan with customized allocator: "
    447         << status.error_message();
    448   }
    449   return std::move(fft_plan_ptr);
    450 }
    451 
    452 void CUDAFft::UpdatePlanWithScratchAllocator(
    453     Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
    454   CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
    455   port::Status status =
    456       cuda_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
    457   if (!status.ok()) {
    458     LOG(FATAL) << "failed to update custom allocator for cufft plan: "
    459                << status.error_message();
    460   }
    461 }
    462 
    463 template <typename FuncT, typename InputT, typename OutputT>
    464 bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
    465                             const DeviceMemory<InputT> &input,
    466                             DeviceMemory<OutputT> *output) {
    467   CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
    468   if (cuda_fft_plan == nullptr) {
    469     LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
    470     return false;
    471   }
    472 
    473   if (!SetStream(parent_, cuda_fft_plan->GetPlan(), stream)) {
    474     return false;
    475   }
    476 
    477   auto ret = cufftExec(parent_, cuda_fft_plan->GetPlan(),
    478                        CUDAComplex(const_cast<InputT *>(CUDAMemory(input))),
    479                        CUDAComplex(CUDAMemoryMutable(output)));
    480 
    481   if (ret != CUFFT_SUCCESS) {
    482     LOG(ERROR) << "failed to run cuFFT routine: " << ret;
    483     return false;
    484   }
    485 
    486   return true;
    487 }
    488 
    489 template <typename FuncT, typename InputT, typename OutputT>
    490 bool CUDAFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
    491                                          FuncT cufftExec,
    492                                          const DeviceMemory<InputT> &input,
    493                                          DeviceMemory<OutputT> *output) {
    494   CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
    495   if (cuda_fft_plan == nullptr) {
    496     LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
    497     return false;
    498   }
    499 
    500   if (!SetStream(parent_, cuda_fft_plan->GetPlan(), stream)) {
    501     return false;
    502   }
    503 
    504   auto ret = cufftExec(parent_, cuda_fft_plan->GetPlan(),
    505                        CUDAComplex(const_cast<InputT *>(CUDAMemory(input))),
    506                        CUDAComplex(CUDAMemoryMutable(output)),
    507                        cuda_fft_plan->GetFftDirection());
    508 
    509   if (ret != CUFFT_SUCCESS) {
    510     LOG(ERROR) << "failed to run cuFFT routine: " << ret;
    511     return false;
    512   }
    513 
    514   return true;
    515 }
    516 
// Stamps out the three DoFft overloads for one scalar type __type:
//   __fft_type1 - complex-to-complex exec routine (needs explicit direction,
//                 so it goes through DoFftWithDirectionInternal),
//   __fft_type2 - real-to-complex exec routine (direction implicit),
//   __fft_type3 - complex-to-real exec routine (direction implicit).
#define PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
                                           __fft_type3)                      \
  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<std::complex<__type>> &input,       \
                      DeviceMemory<std::complex<__type>> *output) {          \
    return DoFftWithDirectionInternal(                                       \
        stream, plan, wrap::cufftExec##__fft_type1, input, output);          \
  }                                                                          \
  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<__type> &input,                     \
                      DeviceMemory<std::complex<__type>> *output) {          \
    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type2, input,  \
                         output);                                            \
  }                                                                          \
  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<std::complex<__type>> &input,       \
                      DeviceMemory<__type> *output) {                        \
    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type3, input,  \
                         output);                                            \
  }

// Single precision (C2C/R2C/C2R) and double precision (Z2Z/D2Z/Z2D).
PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(float, C2C, R2C, C2R)
PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D)

#undef PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT
    542 
    543 }  // namespace cuda
    544 }  // namespace gputools
    545 }  // namespace perftools
    546 
    547 namespace gpu = ::perftools::gputools;
    548 
// Module initializer: registers a factory that builds a CUDAFft instance for
// CUDA stream executors, then makes cuFFT the default FFT plugin for the
// CUDA platform. Registration failure is logged but not fatal.
REGISTER_MODULE_INITIALIZER(register_cufft, {
  gpu::port::Status status =
      gpu::PluginRegistry::Instance()
          ->RegisterFactory<gpu::PluginRegistry::FftFactory>(
              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuFftPlugin, "cuFFT",
              [](gpu::internal::StreamExecutorInterface
                     *parent) -> gpu::fft::FftSupport * {
                // The factory only supports CUDA executors; reject others.
                gpu::cuda::CUDAExecutor *cuda_executor =
                    dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
                if (cuda_executor == nullptr) {
                  LOG(ERROR)
                      << "Attempting to initialize an instance of the cuFFT "
                      << "support library with a non-CUDA StreamExecutor";
                  return nullptr;
                }

                return new gpu::cuda::CUDAFft(cuda_executor);
              });
  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuFFT factory: "
               << status.error_message();
  }

  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
                                                     gpu::PluginKind::kFft,
                                                     gpu::cuda::kCuFftPlugin);
});
    576