/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
// TODO(shamanDevel): Enable complex inputs. This will require a specialization
//                    of Gesvd for complex inputs as well as a new kernel
//                    definition to output the singular values as reals
//                    instead of complex values. The current CPU implementation
//                    outputs the singular values as complex values and then
//                    casts them to reals in the python wrapper.
// TODO(rmlarsen/shamanDevel): This could use a bit of cleanup. We don't need to
// pass quite as many raw pointers around. Would also be nice to reduce code
// duplication.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include <algorithm>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

static const char kErrMsg[] =
    "Singular Value Decomposition was not successful. The input might not be "
    "valid.";

typedef Eigen::GpuDevice GPUDevice;

namespace {
// This kernel computes the reduction
//   V' = sum_i (M_i * U_{i,1} * S_i).
// The result is stored in V[batch] and has the same sign as the true value of
// V, which is what this reduction is used to recover.
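// When n == 1, M_i = U_{i,1} * S_1 * V_{1,1}, so each term of the sum equals
// U_{i,1}^2 * S_1^2 * V_{1,1}, a non-negative multiple of V_{1,1}; the sum
// therefore carries the sign of V even when some entries of M are zero.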
template <class Scalar>
__global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
                                      int64 ldu, const Scalar* M,
                                      const Scalar* U, const Scalar* S,
                                      Scalar* V) {
  CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
    CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
      Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
      CudaAtomicAdd(V + batch, v);
    }
  }
}

// Extracts the sign of V:
// V[i] = V[i] >= 0 ? 1 : -1
template <class Scalar>
__global__ void ExtractSignOfVKernel(CudaLaunchConfig config, Scalar* V) {
  CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
    V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1);
  }
}
}  // namespace

// Scalar: The input scalar type (can be complex)
template <class Scalar>
class SvdOpGpu : public AsyncOpKernel {
 public:
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  explicit SvdOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("compute_uv", &compute_uv_));
    OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
  }

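  // Runs the cuSolver SVD over a batch of column-major m x n (m >= n) matrices
  // stored contiguously at input_ptr; the input is overwritten. Singular
  // values are written to outputS_ptr ([batch_size, p] with p = min(m, n)); if
  // compute_uv_ is set, U and V^T are written in column-major layout to
  // outputU_ptr and outputVT_ptr. Per-batch cuSolver status codes are written
  // to dev_info_ptr.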
  void RunSVD(OpKernelContext* context, DoneCallback done, int64 m, int64 n,
              int64 p, int64 batch_size, Scalar* input_ptr,
              RealScalar* outputS_ptr, Scalar* outputU_ptr,
              Scalar* outputVT_ptr, int* dev_info_ptr, CudaSolver* solver) {
    // Save the input matrix.
    // Needed for the n == 1 fix below, since the SVD destroys its input.
    Tensor input_copy;
    if (compute_uv_ && n == 1) {
      OP_REQUIRES_OK_ASYNC(context,
                           solver->allocate_scoped_tensor(
                               DataTypeToEnum<Scalar>::v(),
                               TensorShape({batch_size, m}), &input_copy),
                           done);
      const GPUDevice& d = context->eigen_device<GPUDevice>();
      d.memcpy(input_copy.flat<Scalar>().data(), input_ptr,
               batch_size * m * sizeof(Scalar));
    }

    for (int64 batch = 0; batch < batch_size; ++batch) {
      Scalar* input = input_ptr + batch * m * n;
      RealScalar* outputS = outputS_ptr + batch * p;
      Scalar* outputU = NULL;
      Scalar* outputVT = NULL;
      char jobu = 'N';
      char jobvt = 'N';

      if (compute_uv_) {
        if (full_matrices_) {
          outputU = outputU_ptr + batch * m * m;
          outputVT = outputVT_ptr + batch * n * n;
          jobu = 'A';
          jobvt = 'A';
        } else {
          outputU = outputU_ptr + batch * m * p;
          outputVT = outputVT_ptr + batch * n * p;
          jobu = 'S';
          jobvt = 'S';
        }
      }

      OP_REQUIRES_OK_ASYNC(
          context,
          solver->Gesvd(jobu, jobvt, m, n, input, m, outputS, outputU, m,
                        outputVT, n, dev_info_ptr + batch),
          done);
    }

    // This works around a bug in cuSolver:
    // if n is one, outputVT contains only zeros instead of +/-1, so it has to
    // be filled in manually.
    // The question is: +1 or -1?
    // -> Compute U*S and compare its sign against M.
    // Because S is zero except for the first entry, the multiplication
    // simplifies considerably.
    // However, what happens if M contains zeros? At those indices it is
    // impossible to determine the value of V.
    // -> Compute V over all rows of M to cope with zeros:
    // 1. V' = sum_i (M_i * U_{i,1} * S_i)
    // 2. V = +1 if V' >= 0, -1 if V' < 0
    // TODO: What about complex values?
    if (compute_uv_ && n == 1) {
      // 1. compute the (batched) sum
      const GPUDevice& d = context->eigen_device<GPUDevice>();
      d.memset(outputVT_ptr, 0, batch_size * sizeof(Scalar));
      Cuda2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d);
      ComputeValueOfVKernel<<<cfg2D.block_count, cfg2D.thread_per_block, 0,
                              d.stream()>>>(
          cfg2D, m, full_matrices_ ? m : p, input_copy.flat<Scalar>().data(),
          outputU_ptr, outputS_ptr, outputVT_ptr);
      // 2. clamp V to -1 or +1
      CudaLaunchConfig cfg1D = GetCudaLaunchConfig(batch_size, d);
      ExtractSignOfVKernel<<<cfg1D.block_count, cfg1D.thread_per_block, 0,
                             d.stream()>>>(cfg1D, outputVT_ptr);
    }
  }

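  // Copies the per-batch cuSolver status codes back to the host, converts any
  // failure into an InvalidArgument error (kErrMsg), invokes the done
  // callback, and releases the solver.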
  void CheckResult(OpKernelContext* context, DoneCallback done,
                   const std::vector<DeviceLapackInfo>& dev_info,
                   std::unique_ptr<CudaSolver> solver) {
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& /* unused */) {
      Status full_status = status;
      if (!full_status.ok()) {
        full_status.Update(errors::InvalidArgument(kErrMsg));
      }
      OP_REQUIRES_OK_ASYNC(context, full_status, done);
      done();
    };

    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }

  // Computes the SVD when m >= n.
  // TODO: Can the two cases (MgeqN and MlessN) be simplified, common
  //   boilerplate be reduced, or even combined into one method?
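  // Steps: transpose M into a scratch buffer so cuSolver sees a column-major
  // matrix, run gesvd per batch writing U into a column-major scratch buffer
  // and V^T directly into V (which then holds V in row-major layout),
  // transpose U back into its output tensor, and finally check the status
  // codes.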
  void PerformSVD_MgeqN(OpKernelContext* context, DoneCallback done, int64 m,
                        int64 n, int64 p, const Tensor& M, Tensor* S, Tensor* U,
                        Tensor* V) {
    TensorShape shapeRaw = M.shape();
    shapeRaw.RemoveLastDims(2);

    // Transpose M, because cuSolver expects it to be column-major.
    TensorShape input_shape = shapeRaw;
    input_shape.AddDim(n);
    input_shape.AddDim(m);
    Tensor input_copy;
    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(M.dtype(), input_shape, &input_copy),
        done);
    auto device = context->eigen_device<GPUDevice>();
    OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, M, &input_copy),
                         done);

    // U needs to be transposed at the end, but not V, because cuSolver works
    // in column-major order.
    Tensor u_copy;
    if (compute_uv_) {
      TensorShape u_shape;
      if (full_matrices_) {
        u_shape = U->shape();
      } else {
        u_shape = shapeRaw;
        u_shape.AddDim(p);
        u_shape.AddDim(m);
      }
      OP_REQUIRES_OK_ASYNC(
          context, solver->allocate_scoped_tensor(U->dtype(), u_shape, &u_copy),
          done);
    }

    // get the pointers to the data
    Scalar* input_ptr;
    RealScalar* outputS_ptr;
    Scalar* outputU_ptr = NULL;
    Scalar* outputV_ptr = NULL;
    auto input_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    input_ptr = input_reshaped.data();
    outputS_ptr = S->template flat_inner_dims<RealScalar, 2>().data();
    if (compute_uv_) {
      outputU_ptr = u_copy.template flat_inner_dims<Scalar, 3>().data();
      outputV_ptr = V->template flat_inner_dims<Scalar, 3>().data();
    }

    // call the SVD
    const int64 batch_size = input_reshaped.dimension(0);
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "gesvd"));
    RunSVD(context, done, m, n, p, batch_size, input_ptr, outputS_ptr,
           outputU_ptr, outputV_ptr, dev_info.back().mutable_data(),
           solver.get());

    // Transpose U
    if (compute_uv_) {
      OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, u_copy, U), done);
    }

    // now check if the SVD operation succeeded or not
    CheckResult(context, std::move(done), dev_info, std::move(solver));
  }

  // Computes the SVD when m < n.
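  // The row-major m x n input, reinterpreted as column-major, is M^T: an
  // n x m matrix with more rows than columns, which can be fed to gesvd
  // directly. From M^T = W * S * Z^T it follows that M = Z * S * W^T, so U of
  // M is Z and V of M is W; this is why U/V and m/n are swapped below.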
  void PerformSVD_MlessN(OpKernelContext* context, DoneCallback done, int64 m,
                         int64 n, int64 p, const Tensor& M, Tensor* S,
                         Tensor* U, Tensor* V) {
    // Perform the SVD on M'.

    // Reuse the input buffer or make a copy for the SVD depending on whether
    // this op owns the input buffer exclusively. This is needed because the
    // SVD modifies its input.
    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, M.shape(), &input_copy),
        done);

    if (!M.SharesBufferWith(input_copy)) {
      const GPUDevice& d = context->eigen_device<GPUDevice>();
      d.memcpy(input_copy.flat<Scalar>().data(), M.flat<Scalar>().data(),
               M.NumElements() * sizeof(Scalar));
    }

    // V needs to be transposed at the end.
    Tensor v_copy;
    if (compute_uv_) {
      TensorShape v_shape;
      if (full_matrices_) {
        v_shape = V->shape();
      } else {
        TensorShape shapeRaw = M.shape();
        shapeRaw.RemoveLastDims(2);
        v_shape = shapeRaw;
        v_shape.AddDim(p);
        v_shape.AddDim(n);
      }
      OP_REQUIRES_OK_ASYNC(
          context, solver->allocate_scoped_tensor(V->dtype(), v_shape, &v_copy),
          done);
    }

    // get the pointers to the data
    Scalar* input_ptr;
    RealScalar* outputS_ptr;
    Scalar* outputU_ptr = NULL;
    Scalar* outputV_ptr = NULL;
    auto input_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    input_ptr = input_reshaped.data();
    outputS_ptr = S->template flat_inner_dims<RealScalar, 2>().data();
    if (compute_uv_) {
      // Note that U and V are flipped
      outputU_ptr = v_copy.template flat_inner_dims<Scalar, 3>().data();
      outputV_ptr = U->template flat_inner_dims<Scalar, 3>().data();
    }

    // call the SVD
    const int64 batch_size = input_reshaped.dimension(0);
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "gesvd"));
    // Note that m and n are flipped
    RunSVD(context, done, n, m, p, batch_size, input_ptr, outputS_ptr,
           outputU_ptr, outputV_ptr, dev_info.back().mutable_data(),
           solver.get());

    // Transpose V
    if (compute_uv_) {
      auto device = context->eigen_device<GPUDevice>();
      OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, v_copy, V), done);
    }

    // now check if the SVD operation succeeded or not
    CheckResult(context, std::move(done), dev_info, std::move(solver));
  }

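  // Validates the input, allocates the S, U, and V outputs, and dispatches to
  // the m >= n or m < n code path; the SVD itself is always run on a matrix
  // with at least as many rows as columns, as required by cuSolver's gesvd.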
  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();

    // Validate inputs before reading the last two dimensions.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);

    const int64 m = input.dim_size(ndims - 2);
    const int64 n = input.dim_size(ndims - 1);
    const int64 p = std::min(m, n);

    // output tensors.
    Tensor* outputU = NULL;
    Tensor* outputS = NULL;
    Tensor* outputV = NULL;

    // compute shapes
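    // S: [..., p]; U: [..., m, m] or [..., m, p]; V: [..., n, n] or
    // [..., n, p], depending on full_matrices. U and V collapse to shape [0]
    // when compute_uv is false.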
    TensorShape shapeRaw = input.shape();
    shapeRaw.RemoveLastDims(2);
    TensorShape shapeS = shapeRaw;
    TensorShape shapeU = shapeRaw;
    TensorShape shapeV = shapeRaw;
    shapeS.AddDim(p);
    if (compute_uv_) {
      if (full_matrices_) {
        shapeU.AddDim(m);
        shapeU.AddDim(m);
        shapeV.AddDim(n);
        shapeV.AddDim(n);
      } else {
        shapeU.AddDim(m);
        shapeU.AddDim(p);
        shapeV.AddDim(n);
        shapeV.AddDim(p);
      }
    } else {
      shapeU = TensorShape({0});
      shapeV = TensorShape({0});
    }

    // allocate output
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shapeS, &outputS),
                         done);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, shapeU, &outputU),
                         done);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(2, shapeV, &outputV),
                         done);

    if (n == 0 || m == 0) {
      // If the input matrix is empty (m == 0 or n == 0), all outputs are empty
      // as well, so there is nothing left to compute.
      done();
      return;
    }

    // call implementations
    if (m >= n) {
      PerformSVD_MgeqN(context, done, m, n, p, input, outputS, outputU,
                       outputV);
    } else {
      PerformSVD_MlessN(context, done, m, n, p, input, outputS, outputU,
                        outputV);
    }
  }

 private:
  bool compute_uv_;
  bool full_matrices_;
};

// TODO: add support for complex types
REGISTER_LINALG_OP_GPU("Svd", (SvdOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Svd", (SvdOpGpu<double>), double);

// Deprecated kernels.
REGISTER_LINALG_OP_GPU("BatchSvd", (SvdOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("BatchSvd", (SvdOpGpu<double>), double);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA