/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

// See docs in ../ops/spectral_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif

namespace tensorflow {

class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
    if (input_shape.num_elements() == 0) {
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};
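
// Worked example of the shape logic in Compute() above (a sketch, not
// exercised by any code here): for RFFT with input shape [batch, 8] and
// fft_length = [8], the output shape becomes [batch, 8 / 2 + 1] = [batch, 5].
// For IRFFT with fft_length = [8], the inner-most input dimension must be at
// least 8 / 2 + 1 = 5 and the output shape is [batch, 8]. Complex-to-complex
// transforms take no fft_length input and read fft_shape directly from the
// trailing dimensions of the (unchanged) output shape.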

typedef Eigen::ThreadPoolDevice CPUDevice;

template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();

    if (!IsReal()) {
      auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
      // Compute the FFT using eigen.
      auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      output.device(device) =
          input.template fft<Eigen::BothParts, direction>(axes);
    } else {
      if (IsForward()) {
        auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
        const auto input_dims = input.dimensions();

        // Slice input to fft_shape on its inner-most dimensions.
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
        input_slice_sizes[0] = input_dims[0];
        TensorShape temp_shape{input_dims[0]};
        for (int i = 1; i <= FFTRank; ++i) {
          input_slice_sizes[i] = fft_shape[i - 1];
          temp_shape.AddDim(fft_shape[i - 1]);
        }

        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

        // Compute the full FFT using a temporary tensor.
        Tensor temp;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
                                               temp_shape, &temp));
        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
        full_fft.device(device) =
            input.slice(zero_start_indices, input_slice_sizes)
                .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

        // Slice away the negative frequency components.
        output.device(device) =
            full_fft.slice(zero_start_indices, output.dimensions());
      } else {
        // Reconstruct the full FFT and take the inverse.
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<float, FFTRank + 1>();
        const auto input_dims = input.dimensions();

        // Calculate the shape of the temporary tensor for the full FFT and the
        // region we will slice from input given fft_shape. We slice input to
        // fft_shape on its inner-most dimensions, except the last (which we
        // slice to fft_shape[-1] / 2 + 1).
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
        input_slice_sizes[0] = input_dims[0];
        TensorShape full_fft_shape;
        full_fft_shape.AddDim(input_dims[0]);
        for (auto i = 1; i <= FFTRank; i++) {
          input_slice_sizes[i] =
              i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
          full_fft_shape.AddDim(fft_shape[i - 1]);
        }

        Tensor temp;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
                                               full_fft_shape, &temp));
        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();

        // Calculate the starting point and range of the source of the
        // negative frequency part.
        auto neg_sizes = input_slice_sizes;
        neg_sizes[FFTRank] =
            fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
        neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
        neg_start_indices[FFTRank] = 1;

        full_fft.slice(start_indices, input_slice_sizes).device(device) =
            input.slice(start_indices, input_slice_sizes);

        // First, conduct IFFTs on outer dimensions. We save computation (and
        // avoid touching uninitialized memory) by slicing full_fft to the
        // subregion we wrote input to.
        if (FFTRank > 1) {
          const auto outer_axes =
              Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
          full_fft.slice(start_indices, input_slice_sizes).device(device) =
              full_fft.slice(start_indices, input_slice_sizes)
                  .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(
                      outer_axes);
        }

        // Reconstruct the full FFT by appending reversed and conjugated
        // spectrum as the negative frequency part.
        Eigen::array<bool, FFTRank + 1> reverse_last_axis;
        for (auto i = 0; i <= FFTRank; i++) {
          reverse_last_axis[i] = i == FFTRank;
        }

        if (neg_sizes[FFTRank] != 0) {
          full_fft.slice(neg_target_indices, neg_sizes).device(device) =
              full_fft.slice(neg_start_indices, neg_sizes)
                  .reverse(reverse_last_axis)
                  .conjugate();
        }

        auto inner_axis = Eigen::array<int, 1>{FFTRank};
        output.device(device) =
            full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(
                inner_axis);
      }
    }
  }
};
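
// Sketch of the Hermitian reconstruction used in the IRFFT branch above: a
// length-n real signal has a spectrum satisfying X[n - k] = conj(X[k]), so
// only bins 0 .. n / 2 arrive as input. The missing negative-frequency bins
// are rebuilt by reversing and conjugating bins 1 .. n - n / 2 - 1 (the slice
// at neg_start_indices); e.g. for n = 8, bins 5, 6, 7 are filled with
// conj(X[3]), conj(X[2]), conj(X[1]). The final inverse FFT over the inner
// axis then keeps only the real part. Note that flat_inner_dims collapses all
// batch dimensions into axis 0, which is never transformed.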

// Use labels to distinguish between internal and open source versions
// of these kernels.
#ifdef PLATFORM_GOOGLE
#define FFT_LABEL "eigen"
#else
#define FFT_LABEL ""
#endif

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 3>);

#undef FFT_LABEL

#if GOOGLE_CUDA
namespace gpu = ::perftools::gputools;

namespace {
template <typename T>
gpu::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  gpu::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  gpu::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
gpu::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  gpu::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  gpu::DeviceMemory<T> typed(wrapped);
  return typed;
}

// A class that provides a scratch-space allocator for the StreamExecutor
// cuFFT callback. TensorFlow is responsible for releasing the temporary
// buffers after the kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public gpu::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64 memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64 GetMemoryLimitInBytes(gpu::Stream* stream) override {
    return memory_limit_;
  }
  gpu::port::StatusOr<gpu::DeviceMemory<uint8>> AllocateBytes(
      gpu::Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;
    if (byte_size > memory_limit_) {
      return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.no_retry_on_failure = true;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>();
    }
    // Hold references to the allocated tensors until the allocator is
    // destroyed.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 memory_limit_;
  int64 total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
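
// Minimal usage sketch (mirroring FFTGPUBase::DoFFT below): the allocator is
// constructed per call with the configured scratch limit and handed to plan
// creation, which may call AllocateBytes while building the cuFFT plan.
//
//   CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
//   auto plan =
//       stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
//           stream, fft_rank, fft_shape, ..., &scratch_allocator);
//
// The tensors stored in allocated_tensors_ keep the scratch buffers alive
// until the allocator goes out of scope at the end of DoFFT.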

}  // end namespace

int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
                             int64 default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64 scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}
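
// Example of the limit computation above: TF_CUFFT_WORKSPACE_LIMIT_IN_MB=64
// yields a scratch limit of 64 * (1 << 20) = 67108864 bytes, while an unset
// or empty variable falls back to default_value_in_bytes, which is already
// expressed in bytes (see the CufftScratchSize definition below).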

class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  static int64 CufftScratchSize;
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    const auto kFftType =
        IsReal() ? (IsForward() ? gpu::fft::Type::kR2C : gpu::fft::Type::kC2R)
                 : (IsForward() ? gpu::fft::Type::kC2CForward
                                : gpu::fft::Type::kC2CInverse);

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);

    if (IsReal()) {
      if (IsForward()) {
        auto src = AsDeviceMemory<float>(in.flat<float>().data());
        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
      } else {
        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
        auto dst = AsDeviceMemory<float>(out->flat<float>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
        auto alpha = 1.f / output_distance;
        OP_REQUIRES(
            ctx,
            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                .ok(),
            errors::Internal("BlasScal failed : in.shape=",
                             input_shape.DebugString()));
      }
    } else {
      auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
      auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
      OP_REQUIRES(
          ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
          errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                           " in.shape=", input_shape.DebugString()));
      if (!IsForward()) {
        auto alpha = complex64(1.f / output_distance);
        OP_REQUIRES(
            ctx,
            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                .ok(),
            errors::Internal("BlasScal failed : in.shape=",
                             input_shape.DebugString()));
      }
    }
  }
};
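
// Normalization note with a worked example: cuFFT inverse transforms are
// unnormalized, so the C2R and inverse C2C branches above rescale the result
// by alpha = 1 / output_distance, where output_distance is the product of the
// transformed output dimensions. For a 2-D IFFT over trailing dims [8, 16],
// output_distance = 128 and every output element is scaled by 1 / 128 via
// ThenBlasScal.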

int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // default value is in bytes despite the name of the environment variable
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU), FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow