/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/random_op.h"

#include <assert.h>
#include <stdio.h>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

class OpKernelContext;

namespace functor {

typedef Eigen::GpuDevice GPUDevice;

template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;

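// Copies a group of samples produced by a Distribution into the output
// buffer. The generic version below copies one element at a time; the
// specializations that follow use a single vectorized 128-bit store for
// common (element type, group size) pairs.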
template <typename T, int ElementCount>
class SampleCopier {
 public:
  inline __device__ void operator()(
      T* buf, const tensorflow::random::Array<T, ElementCount>& array) const {
#pragma unroll
    for (int i = 0; i < ElementCount; i++) {
      buf[i] = array[i];
    }
  }
};

template <>
class SampleCopier<float, 4> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      float* buf, const tensorflow::random::Array<float, 4>& array) const {
    // NOTE(ringwalt): It's not safe to cast &array[0] to a float4, because the
    // array elements only have 32-bit alignment while float4 requires 128-bit
    // alignment. There seems to be no performance loss from assigning each
    // element to the vector individually.
    float4 vec;
    vec.x = array[0];
    vec.y = array[1];
    vec.z = array[2];
    vec.w = array[3];
    float4* buf_vector = reinterpret_cast<float4*>(buf);
    *buf_vector = vec;
  }
};
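
// Sketch of the unsafe alternative the NOTE above rules out (illustrative,
// not code from this file):
//
//   const float4* v = reinterpret_cast<const float4*>(&array[0]);
//   float4 vec = *v;  // UB: float4 loads need 16-byte alignment, but
//                     // random::Array<float, 4> only guarantees the 4-byte
//                     // alignment of its float elements.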

template <>
class SampleCopier<int32, 4> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      int32* buf, const tensorflow::random::Array<int32, 4>& array) const {
    int4 vec;
    vec.x = array[0];
    vec.y = array[1];
    vec.z = array[2];
    vec.w = array[3];
    int4* buf_vector = reinterpret_cast<int4*>(buf);
    *buf_vector = vec;
  }
};

template <>
class SampleCopier<double, 2> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      double* buf, const tensorflow::random::Array<double, 2>& array) const {
    double2 vec;
    vec.x = array[0];
    vec.y = array[1];
    double2* buf_vector = reinterpret_cast<double2*>(buf);
    *buf_vector = vec;
  }
};

template <>
class SampleCopier<int64, 2> {
 public:
  // Copies the elements from the array to buf. buf must be 128-bit aligned,
  // which is true for tensor data, and for all offsets that are a multiple of
  // the vector size (because the vectors are 128 bits long).
  inline __device__ void operator()(
      int64* buf, const tensorflow::random::Array<int64, 2>& array) const {
    longlong2 vec;
    vec.x = array[0];
    vec.y = array[1];
    longlong2* buf_vector = reinterpret_cast<longlong2*>(buf);
    *buf_vector = vec;
  }
};
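
// Worked example of the alignment argument above: each specialization writes
// exactly 16 bytes per sample group, so group k lands at byte offset k * 16
// from the start of the output buffer. As long as the buffer itself is
// 128-bit (16-byte) aligned, which the comments above assume holds for tensor
// data, every vectorized store target is therefore also 16-byte aligned.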

// A CUDA kernel to fill the data with random numbers from the specified
// distribution. Each output takes a fixed number of samples.
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size,
                              Distribution dist) {
    const int kGroupSize = Distribution::kResultElementCount;

    const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_thread_count = gridDim.x * blockDim.x;
    int32 offset = thread_id * kGroupSize;
    // Offset the generator so that each thread starts on its own sample group.
    gen.Skip(thread_id);

    const SampleCopier<T, kGroupSize> copier;
    while (offset + kGroupSize <= size) {
      const typename Distribution::ResultType samples = dist(&gen);
      copier(&data[offset], samples);

      offset += total_thread_count * kGroupSize;
      gen.Skip(total_thread_count - 1);
    }

    // Handle the final, possibly partial, group without the vectorized copier.
    typename Distribution::ResultType samples = dist(&gen);
    for (int i = 0; i < kGroupSize; ++i) {
      if (offset >= size) {
        return;
      }
      data[offset] = samples[i];
      ++offset;
    }
  }
};
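
// Worked example of the striding above (illustrative numbers only, assuming
// each dist(&gen) call consumes exactly one sample group, as the
// fixed-samples distributions do): with gridDim.x * blockDim.x == 1024
// threads and kGroupSize == 4, thread t first writes elements [4t, 4t + 4)
// using generator group t. After copying, offset advances by
// total_thread_count * kGroupSize == 4096 elements and the generator skips
// the 1023 groups claimed by the other threads, so the threads' counter
// ranges never overlap.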

// A CUDA kernel to fill the data with random numbers from the specified
// distribution. Each output takes a variable number of samples.
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data,
                              int64 size, Distribution dist) {
    using random::PhiloxRandom;
    using random::SingleSampleAdapter;

    const int kReservedSamplesPerOutput = 256;
    const int kGroupSize = Distribution::kResultElementCount;
    const int kGeneratorSkipPerOutputGroup = kGroupSize *
                                             kReservedSamplesPerOutput /
                                             PhiloxRandom::kResultElementCount;

    const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_thread_count = gridDim.x * blockDim.x;
    int64 group_index = thread_id;
    int64 offset = group_index * kGroupSize;

    while (offset < size) {
      // Since each output takes a variable number of samples, we need to
      // realign the generator to the beginning of the current output group.
      PhiloxRandom gen = base_gen;
      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      typename Distribution::ResultType samples = dist(&single_samples);

      for (int i = 0; i < kGroupSize; ++i) {
        if (offset >= size) {
          return;
        }
        data[offset] = samples[i];
        ++offset;
      }

      offset += (total_thread_count - 1) * kGroupSize;
      group_index += total_thread_count;
    }
  }
};
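
// Worked example of the skip computation above (illustrative; assumes
// PhiloxRandom::kResultElementCount == 4, i.e. four 32-bit outputs per
// counter value): with kGroupSize == 4, kGeneratorSkipPerOutputGroup is
// 4 * 256 / 4 == 256 counter increments per output group, i.e. 1024 32-bit
// samples, or kReservedSamplesPerOutput == 256 single samples per output.
// A rejection sampler such as TruncatedNormalDistribution can therefore draw
// a variable number of samples per output while distinct output groups stay
// on disjoint counter ranges.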

// A simple launch pad to call the correct function templates to fill the data.
template <class Distribution>
__global__ void __launch_bounds__(1024)
    FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
                                 typename Distribution::ResultElementType* data,
                                 int64 size, Distribution dist) {
  FillPhiloxRandomKernel<Distribution,
                         Distribution::kVariableSamplesPerOutput>()
      .Run(base_gen, data, size, dist);
}

// Partial specialization for GPU
template <class Distribution>
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
    OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  const int32 block_size = d.maxCudaThreadsPerBlock();
  // Size the grid so that the total thread count matches the maximum number
  // of resident threads across all multiprocessors.
  const int32 num_blocks =
      (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
      block_size;

  FillPhiloxRandomKernelLaunch<Distribution>
      <<<num_blocks, block_size, 0, d.stream()>>>(gen, data, size, dist);
}
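
// Illustrative call site (a sketch, not code from this file; the actual op
// kernels that invoke this functor live elsewhere, e.g. in random_op.cc):
//
//   // Hypothetical names: ctx is an OpKernelContext*, philox is a seeded
//   // random::PhiloxRandom, and output is the result Tensor.
//   using Dist = random::NormalDistribution<random::PhiloxRandom, float>;
//   functor::FillPhiloxRandom<GPUDevice, Dist>()(
//       ctx, ctx->eigen_device<GPUDevice>(), philox,
//       output->flat<float>().data(), output->NumElements(), Dist());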

// Explicit instantiation of the GPU distribution functors.
// clang-format off
// NVCC cannot handle ">>" properly
template struct FillPhiloxRandom<
    GPUDevice, random::UniformDistribution<random::PhiloxRandom, Eigen::half> >;
template struct FillPhiloxRandom<
    GPUDevice, random::UniformDistribution<random::PhiloxRandom, float> >;
template struct FillPhiloxRandom<
    GPUDevice, random::UniformDistribution<random::PhiloxRandom, double> >;
template struct FillPhiloxRandom<
    GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >;
template struct FillPhiloxRandom<
    GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >;
template struct FillPhiloxRandom<
    GPUDevice, random::NormalDistribution<random::PhiloxRandom, Eigen::half> >;
template struct FillPhiloxRandom<
    GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
template struct FillPhiloxRandom<
    GPUDevice, random::NormalDistribution<random::PhiloxRandom, double> >;
template struct FillPhiloxRandom<
    GPUDevice, random::TruncatedNormalDistribution<
        random::SingleSampleAdapter<random::PhiloxRandom>, Eigen::half> >;
template struct FillPhiloxRandom<
    GPUDevice, random::TruncatedNormalDistribution<
                   random::SingleSampleAdapter<random::PhiloxRandom>, float> >;
template struct FillPhiloxRandom<
    GPUDevice, random::TruncatedNormalDistribution<
                   random::SingleSampleAdapter<random::PhiloxRandom>, double> >;
// clang-format on

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA