/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using perftools::gputools::dnn::DimIndex;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

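// LaunchConvOp runs the actual 3D convolution for a given device type.
// Conv3DOp (below) validates attributes and shapes and then dispatches to the
// matching specialization. The CPU specialization only supports the NHWC
// tensor format and delegates to the Eigen CuboidConvolution functor.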
template <typename Device, typename T>
struct LaunchConvOp;

template <typename T>
struct LaunchConvOp<CPUDevice, T> {
  static void launch(OpKernelContext* context, bool cudnn_use_autotune,
                     const Tensor& input, const Tensor& filter,
                     const std::array<int64, 3>& strides, const Padding padding,
                     TensorFormat data_format, Tensor* output) {
    OP_REQUIRES(context, data_format == FORMAT_NHWC,
                errors::InvalidArgument("CPU implementation of Conv3D "
                                        "currently only supports the NHWC "
                                        "tensor format."));
    functor::CuboidConvolution<CPUDevice, T>()(
        context->eigen_device<CPUDevice>(), output->tensor<T, 5>(),
        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
        strides[0], BrainPadding2EigenPadding(padding));
  }
};

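// Conv3DOp checks the attributes and the input/filter shapes, computes the
// output shape via Get3dOutputSize, allocates the output, and hands the work
// to LaunchConvOp<Device, T>.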
template <typename Device, typename T>
class Conv3DOp : public BinaryOp<T> {
 public:
  explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
         GetTensorDim(stride_, data_format_, 'C') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_z, in_y, in_x, in_channels ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
    const Tensor& filter = context->input(1);

    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
    // kept consistent between input/filter/output.
    OP_REQUIRES(context, input.dims() == 5,
                errors::InvalidArgument("input must be 5-dimensional"));
    OP_REQUIRES(context, filter.dims() == 5,
                errors::InvalidArgument("filter must be 5-dimensional"));

    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    const int64 in_batch = GetTensorDim(input, data_format_, 'N');

    const int64 out_depth = filter.dim_size(4);
    OP_REQUIRES(
        context, in_depth == filter.dim_size(3),
        errors::InvalidArgument("input and filter must have the same depth"));

    // Dimension order for these arrays is: z, y, x.
    std::array<int64, 3> input_size = {
        {GetTensorDim(input, data_format_, '0'),
         GetTensorDim(input, data_format_, '1'),
         GetTensorDim(input, data_format_, '2')}};
    std::array<int64, 3> filter_size = {
        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
    std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'),
                                     GetTensorDim(stride_, data_format_, '1'),
                                     GetTensorDim(stride_, data_format_, '2')}};
    std::array<int64, 3> out, padding;

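    // Compute the output spatial dimensions and the padding needed for each
    // spatial dimension; both arrays follow the same z, y, x order as above.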
    OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides,
                                            padding_, &out, &padding));
    TensorShape out_shape = ShapeFromFormat(
        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // Return early if nothing to do.
    if (out_shape.num_elements() == 0) return;

    LaunchConvOp<Device, T>::launch(context, cudnn_use_autotune_, input, filter,
                                    strides, padding_, data_format_, output);
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool cudnn_use_autotune_;
};

#define REGISTER_CPU_KERNEL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if GOOGLE_CUDA

// A dummy type to group forward convolution autotune results together.
struct Conv3dAutoTuneGroup {
  static string name() { return "Conv3d"; }
};
typedef AutoTuneSingleton<Conv3dAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv3d;

// TODO(mjanusz): Share logic with 2d implementation as much as possible.
template <typename T>
struct LaunchConvOp<GPUDevice, T> {
  static void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
                     const Tensor& input_param, const Tensor& filter,
                     const std::array<int64, 3>& strides, const Padding padding,
                     TensorFormat data_format, Tensor* output) {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    Tensor input = input_param;

    const int64 in_batch = GetTensorDim(input, data_format, 'N');
    int64 in_planes = GetTensorDim(input, data_format, '0');
    int64 in_rows = GetTensorDim(input, data_format, '1');
    int64 in_cols = GetTensorDim(input, data_format, '2');
    const int64 in_depth = GetTensorDim(input, data_format, 'C');

    const int64 filter_planes = filter.dim_size(0);
    const int64 filter_rows = filter.dim_size(1);
    const int64 filter_cols = filter.dim_size(2);
    const int64 out_depth = filter.dim_size(4);

    int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
    int64 out_planes = GetTensorDim(*output, data_format, '0');
    int64 out_rows = GetTensorDim(*output, data_format, '1');
    int64 out_cols = GetTensorDim(*output, data_format, '2');

    if (padding == Padding::SAME) {
      pad_planes = std::max<int64>(
          0, (out_planes - 1) * strides[0] + filter_planes - in_planes);
      pad_rows = std::max<int64>(
          0, (out_rows - 1) * strides[1] + filter_rows - in_rows);
      pad_cols = std::max<int64>(
          0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
    }

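    // Fast paths: when the convolution degenerates into a matrix multiply,
    // skip cuDNN and call cuBLAS directly. With a 1x1x1 filter and unit
    // strides, the NHWC input can be viewed as a
    // [batch * planes * rows * cols, in_depth] matrix and the filter as an
    // [in_depth, out_depth] matrix, so the convolution is a single GEMM.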
    // NOTE: This only works in NHWC.
    if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 &&
        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
        data_format == FORMAT_NHWC) {
      // 1x1 filter, so call cublas directly.
      const uint64 m = in_batch * in_planes * in_rows * in_cols;
      const uint64 k = in_depth;
      const uint64 n = out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                  output->template flat<T>().size());

      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
                             n, a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                        ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_planes == in_planes && filter_rows == in_rows &&
               filter_cols == in_cols && padding == Padding::VALID &&
               data_format == FORMAT_NHWC) {
      // The input data and filter have the same planes/height/width, so call
      // cublas directly.
      const uint64 m = in_batch;
      const uint64 k = in_planes * in_rows * in_cols * in_depth;
      const uint64 n = out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                  output->template flat<T>().size());

      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
      bool blas_launch_status =
          stream
              ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
                             n, a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                        ", n=", n, ", k=", k));
      }
      return;
    }

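    // cuDNN expects the same amount of padding on both sides of each spatial
    // dimension. If the total SAME padding computed above is odd, pad the
    // input by one extra plane/row/column at the far end first, so that the
    // remaining padding can be split evenly.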
    if (padding == Padding::SAME) {
      const bool rows_odd = (pad_rows % 2 != 0);
      const bool cols_odd = (pad_cols % 2 != 0);
      const bool planes_odd = (pad_planes % 2 != 0);

      // Necessary because cuDNN only supports symmetric padding.
      // TODO(mjanusz): Consider making this optional? This would save some
      // overhead and would work as long as an op trained this way is only
      // used on GPU.
      if (rows_odd || cols_odd || planes_odd) {
        const int64 new_in_rows = in_rows + rows_odd;
        const int64 new_in_cols = in_cols + cols_odd;
        const int64 new_in_planes = in_planes + planes_odd;

        Tensor transformed_input;
        TensorShape transformed_shape = ShapeFromFormat(
            data_format, in_batch, {{new_in_planes, new_in_rows, new_in_cols}},
            in_depth);
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, transformed_shape,
                                    &transformed_input));

        functor::PadInput<GPUDevice, T, int, 5>()(
            ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 5>()),
            {{0, 0, 0}}, {{planes_odd, rows_odd, cols_odd}},
            To32Bit(transformed_input.tensor<T, 5>()), data_format);
        input = transformed_input;
        in_rows = new_in_rows;
        in_cols = new_in_cols;
        in_planes = new_in_planes;
      }
    }

    if (data_format == FORMAT_NHWC) {
      const TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, in_batch, {{in_planes, in_rows, in_cols}}, in_depth);
      if (in_depth > 1) {
        Tensor transformed_input;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                               nchw_shape, &transformed_input));
        // input: [b, x, y, z, d]
        // t_input: [b, d, x, y, z]
        // NCDHW is the only format universally supported by cuDNN.
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            ctx->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
        input = transformed_input;
      } else {
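        // With a single channel the NHWC and NCHW memory layouts coincide, so
        // a shape-only copy (effectively a reshape) is sufficient.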
        CHECK(input.CopyFrom(input, nchw_shape));
      }
    }

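    // Describe the convolution to StreamExecutor/cuDNN. The data is now in
    // NCDHW (kBatchDepthYX) layout, with DimIndex::X/Y/Z mapping to
    // cols/rows/planes; cuDNN pads each side by half of the total padding.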
    CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0)
        << "Negative paddings: (" << pad_rows << ", " << pad_cols << ", "
        << pad_planes << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(in_batch)
        .set_feature_map_count(in_depth)
        .set_spatial_dim(DimIndex::X, in_cols)
        .set_spatial_dim(DimIndex::Y, in_rows)
        .set_spatial_dim(DimIndex::Z, in_planes)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(in_batch)
        .set_spatial_dim(DimIndex::X, out_cols)
        .set_spatial_dim(DimIndex::Y, out_rows)
        .set_spatial_dim(DimIndex::Z, out_planes)
        .set_feature_map_count(out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
        .set_spatial_dim(DimIndex::Y, filter_rows)
        .set_spatial_dim(DimIndex::Z, filter_planes)
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, pad_cols / 2)
        .set_zero_padding(DimIndex::Y, pad_rows / 2)
        .set_zero_padding(DimIndex::Z, pad_planes / 2);

    Tensor transformed_filter;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                TensorShape({out_depth, in_depth, filter_planes,
                                             filter_rows, filter_cols}),
                                &transformed_filter));
    // filter: [x, y, z, in, out]
    // t_filter: [out, in, x, y, z]
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    Tensor transformed_output;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(FORMAT_NCHW, in_batch,
                                 {{out_planes, out_rows, out_cols}}, out_depth),
                 &transformed_output));

    auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                    input.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto output_ptr =
        AsDeviceMemory(transformed_output.template flat<T>().data(),
                       transformed_output.template flat<T>().size());

    static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

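    // Parameters that identify this convolution in the autotune cache:
    // shapes, strides, padding, data type and device.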
    int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    ConvParameters conv_parameters = {
        in_batch,
        in_depth,
        {{in_planes, in_rows, in_cols}},
        out_depth,
        {{filter_planes, filter_rows, filter_cols}},
        // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
        // conv is supported.
        /*dilation=*/{{1, 1, 1}},
        {{strides[0], strides[1], strides[2]}},
        {{pad_planes, pad_rows, pad_cols}},
        dtype,
        device_id,
    };

    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmDesc;
    using perftools::gputools::dnn::ProfileResult;

    AlgorithmConfig algorithm_config;

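    // If autotuning is enabled and there is no cached result for these
    // parameters yet, time every algorithm cuDNN offers once and cache the
    // fastest one (tracking the fastest scratch-free algorithm separately).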
    if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
                                  conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveWithAlgorithm(
                    input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
                    output_desc, &output_ptr, &scratch_allocator,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(ctx,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config);
    }

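    // Run the convolution with the autotuned (or cuDNN default) algorithm,
    // writing into the NCDHW intermediate buffer.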
    CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
                                        filter_ptr, conv_desc, output_desc,
                                        &output_ptr, &scratch_allocator,
                                        algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      ctx->SetStatus(errors::Internal(
          "cuDNN launch failure : input shape(", input.shape().DebugString(),
          ") filter shape(", filter.shape().DebugString(), ")"));
    }

    if (data_format == FORMAT_NHWC) {
      // t_output: [b, out, x, y, z]
      // output: [b, x, y, z, out]
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(transformed_output).tensor<T, 5>(),
          output->tensor<T, 5>());
    } else {
      *output = transformed_output;
    }
  }
};

// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      typename TTypes<T, 5, int>::Tensor out);                        \
  template <>                                                         \
  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      const std::array<int, 3>& padding_left,                         \
      const std::array<int, 3>& padding_right,                        \
      typename TTypes<T, 5, int>::Tensor out, TensorFormat format);

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    Conv3DOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    Conv3DOp<GPUDevice, float>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow