     16 // See docs in ../ops/linalg_ops.cc.
     17 // TODO(shamanDevel): Enable complex inputs. This will require a specialization
     18 //                    of Gesvd for complex inputs as well as a new kernel
     19 //                    definition to output the singular values as reals
     20 //                    instead of complex values. The current CPU implementation
     21 //                    outputs the singular values as complex values and then
     22 //                    casts them to reals in the python wrapper.
     23 // TODO(rmlarsen/shamanDevel): This could use a bit of cleanup. We don't need to
     24 // pass quite as many raw pointers around. Would also be nice to reduce code
     25 // duplication.
     27 #if GOOGLE_CUDA
     28 #define EIGEN_USE_GPU
     30 #include <algorithm>
     31 #include <vector>
     33 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     34 #include "tensorflow/core/framework/kernel_def_builder.h"
     35 #include "tensorflow/core/framework/op_kernel.h"
     36 #include "tensorflow/core/framework/register_types.h"
     37 #include "tensorflow/core/framework/tensor_shape.h"
     38 #include "tensorflow/core/framework/types.h"
     39 #include "tensorflow/core/kernels/cuda_solvers.h"
     40 #include "tensorflow/core/kernels/linalg_ops_common.h"
     41 #include "tensorflow/core/kernels/transpose_functor.h"
     42 #include "tensorflow/core/lib/core/errors.h"
     43 #include "tensorflow/core/platform/logging.h"
     44 #include "tensorflow/core/platform/stream_executor.h"
     45 #include "tensorflow/core/platform/types.h"
     46 #include "tensorflow/core/util/cuda_kernel_helper.h"
     48 namespace tensorflow {
     50 static const char kErrMsg[] =
     51     "Singular Value Decomposition was not successful. The input might not be "
     52     "valid.";
     54 typedef Eigen::GpuDevice GPUDevice;
     56 namespace {
     57 // This kernel computes the reduction
     58 // V' = sum_i (M_i * U_i,1 * S_i).
     59 // The result is stored in V[batch] and has the same sign as the
     60 // real value of V (which should be computed)
     61 template <class Scalar>
     62 __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
     63                                       int64 ldu, const Scalar* M,
     64                                       const Scalar* U, const Scalar* S,
     65                                       Scalar* V) {
     66   CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
     67     CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
     68       Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
     69       CudaAtomicAdd(V + batch, v);
     70     }
     71   }
     72 }
     74 // Extracts the sign of V
     75 // V[i] = V[i]>=0 ? 1 : 0
     76 template <class Scalar>
     77 __global__ void ExtractSignOfVKernel(CudaLaunchConfig config, Scalar* V) {
     78   CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
     79     V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1);
     80   }
     81 }
     82 }  // namespace
     84 // Scalar: The input scalar type (can be complex)
     85 template <class Scalar>
     86 class SvdOpGpu : public AsyncOpKernel {
     87  public:
     88   using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
     90   explicit SvdOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
     91     OP_REQUIRES_OK(context, context->GetAttr("compute_uv", &compute_uv_));
     92     OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
     93   }
     95   void RunSVD(OpKernelContext* context, DoneCallback done, int64 m, int64 n,
     96               int64 p, int64 batch_size, Scalar* input_ptr,
     97               RealScalar* outputS_ptr, Scalar* outputU_ptr,
     98               Scalar* outputVT_ptr, int* dev_info_ptr, CudaSolver* solver) {
     99     // Save the input matrix
    100     // Needed for the n=1 fix, see below, since SVD destroys the input
    101     Tensor input_copy;
    102     if (compute_uv_ && n == 1) {
    103       OP_REQUIRES_OK_ASYNC(context,
    104                            solver->allocate_scoped_tensor(
    105                                DataTypeToEnum<Scalar>::v(),
    106                                TensorShape({batch_size, m}), &input_copy),
    107                            done);
    108       const GPUDevice& d = context->eigen_device<GPUDevice>();
    109       d.memcpy(input_copy.flat<Scalar>().data(), input_ptr,
    110                batch_size * m * sizeof(Scalar));
    111     }
    113     for (int64 batch = 0; batch < batch_size; ++batch) {
    114       Scalar* input = input_ptr + batch * m * n;
    115       RealScalar* outputS = outputS_ptr + batch * p;
    116       Scalar* outputU = NULL;
    117       Scalar* outputVT = NULL;
    118       char jobu = 'N';
    119       char jobvt = 'N';
    121       if (compute_uv_) {
    122         if (full_matrices_) {
    123           outputU = outputU_ptr + batch * m * m;
    124           outputVT = outputVT_ptr + batch * n * n;
    125           jobu = 'A';
    126           jobvt = 'A';
    127         } else {
    128           outputU = outputU_ptr + batch * m * p;
    129           outputVT = outputVT_ptr + batch * n * p;
    130           jobu = 'S';
    131           jobvt = 'S';
    132         }
    133       }
    135       OP_REQUIRES_OK_ASYNC(
    136           context,
    137           solver->Gesvd(jobu, jobvt, m, n, input, m, outputS, outputU, m,
    138                         outputVT, n, dev_info_ptr + batch),
    139           done);
    140     }
    142     // This is a bug in cuSolver:
    143     // If n is one, then outputVT only contains zeros instead of ones.
    144     // Hence, I need to fill outputVT manually
    145     // The question is: +1 or -1?
    146     // -> Compute U*S and compare sign against M
    147     // But because S is zero except for the first entry, the multiplication
    148     // simplifies a lot.
    149     // However, what happens if M contains zeros? At these indices, it is
    150     // impossible to determine the value of V.
    151     // -> Compute V for all rows in M to cope for zeros.
    152     // 1. V' = sum_i (M_i * U_i,1 * S_i)
    153     // 2. V = {1, V'>=0, -1, V'<0}
    154     // TODO: what is with complex values?
    155     if (compute_uv_ && n == 1) {
    156       // 1. compute the (batched) sum
    157       const GPUDevice& d = context->eigen_device<GPUDevice>();
    158       d.memset(outputVT_ptr, 0, batch_size * sizeof(Scalar));
    159       Cuda2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d);
    160       ComputeValueOfVKernel<<<cfg2D.block_count, cfg2D.thread_per_block, 0,
    161                               d.stream()>>>(
    162           cfg2D, m, full_matrices_ ? m : p, input_copy.flat<Scalar>().data(),
    163           outputU_ptr, outputS_ptr, outputVT_ptr);
    164       // 2. clamp V to -1 or +1
    165       CudaLaunchConfig cfg1D = GetCudaLaunchConfig(batch_size, d);
    166       ExtractSignOfVKernel<<<cfg1D.block_count, cfg1D.thread_per_block, 0,
    167                              d.stream()>>>(cfg1D, outputVT_ptr);
    168     }
    169   }
    171   void CheckResult(OpKernelContext* context, DoneCallback done,
    172                    const std::vector<DeviceLapackInfo>& dev_info,
    173                    std::unique_ptr<CudaSolver> solver) {
    174     auto info_checker = [context, done](
    175                             const Status& status,
    176                             const std::vector<HostLapackInfo>& /* unused */) {
    177       Status full_status = status;
    178       if (!full_status.ok()) {
    179         full_status.Update(errors::InvalidArgument(kErrMsg));
    180       }
    181       OP_REQUIRES_OK_ASYNC(context, full_status, done);
    182       done();
    183     };
    185     CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
    186                                                     std::move(info_checker));
    187   }
    189   // The SVD if m >= n
    190   // TODO: can the two cases (MgeqN and MlessN) be simplified,
    191   //   common boilerplate be reduced, or even combined in one method?
    192   void PerformSVD_MgeqN(OpKernelContext* context, DoneCallback done, int64 m,
    193                         int64 n, int64 p, const Tensor& M, Tensor* S, Tensor* U,
    194                         Tensor* V) {
    195     TensorShape shapeRaw = M.shape();
    196     shapeRaw.RemoveLastDims(2);
    198     // Transpose M, because cuSolver expects it to be column-major
    199     TensorShape input_shape = shapeRaw;
    200     input_shape.AddDim(n);
    201     input_shape.AddDim(m);
    202     Tensor input_copy;
    203     // TODO(rmlarsen): Convert to std::make_unique when available.
    204     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
    206         context,
    207         solver->allocate_scoped_tensor(M.dtype(), input_shape, &input_copy),
    208         done);
    209     auto device = context->eigen_device<GPUDevice>();
    210     OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, M, &input_copy),
    211                          done);
    213     // I need to transpose U at the end
    214     // Not V, because cuSolver work column-major
    215     Tensor u_copy;
    216     if (compute_uv_) {
    217       TensorShape u_shape;
    218       if (full_matrices_) {
    219         u_shape = U->shape();
    220       } else {
    221         u_shape = shapeRaw;
    222         u_shape.AddDim(p);
    223         u_shape.AddDim(m);
    224       }
    225       OP_REQUIRES_OK_ASYNC(
    226           context, solver->allocate_scoped_tensor(U->dtype(), u_shape, &u_copy),
    227           done);
    228     }
    230     // get the pointers to the data
    231     Scalar* input_ptr;
    232     RealScalar* outputS_ptr;
    233     Scalar* outputU_ptr = NULL;
    234     Scalar* outputV_ptr = NULL;
    235     auto input_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    236     input_ptr = input_reshaped.data();
    237     outputS_ptr = S->template flat_inner_dims<RealScalar, 2>().data();
    238     if (compute_uv_) {
    239       outputU_ptr = u_copy.template flat_inner_dims<Scalar, 3>().data();
    240       outputV_ptr = V->template flat_inner_dims<Scalar, 3>().data();
    241     }
    243     // call the SVD
    244     const int64 batch_size = input_reshaped.dimension(0);
    245     std::vector<DeviceLapackInfo> dev_info;
    246     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "gesvd"));
    247     RunSVD(context, done, m, n, p, batch_size, input_ptr, outputS_ptr,
    248            outputU_ptr, outputV_ptr, dev_info.back().mutable_data(),
    249            solver.get());
    251     // Transpose U
    252     if (compute_uv_) {
    253       OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, u_copy, U), done);
    254     }
    256     // now check if the SVD operation succeeded or not
    257     CheckResult(context, std::move(done), dev_info, std::move(solver));
    258   }
    260   // The SVD if m < n
    261   void PerformSVD_MlessN(OpKernelContext* context, DoneCallback done, int64 m,
    262                          int64 n, int64 p, const Tensor& M, Tensor* S,
    263                          Tensor* U, Tensor* V) {
    264     // Perform the SVD on M'
    266     // Reuse the input buffer or make a copy for the SVD depending on whether
    267     // this op owns the input buffer exclusively. This is needed because the
    268     // SVD modifies the input
    269     // TODO(rmlarsen): Convert to std::make_unique when available.
    270     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
    271     Tensor input_copy;
    273         context,
    274         solver->forward_input_or_allocate_scoped_tensor(
    275             {0}, DataTypeToEnum<Scalar>::value, M.shape(), &input_copy),
    276         done);
    278     if (!M.SharesBufferWith(input_copy)) {
    279       const GPUDevice& d = context->eigen_device<GPUDevice>();
    280       d.memcpy(input_copy.flat<Scalar>().data(), M.flat<Scalar>().data(),
    281                M.NumElements() * sizeof(Scalar));
    282     }
    284     // I need to transpose V at the end
    285     Tensor v_copy;
    286     if (compute_uv_) {
    287       TensorShape v_shape;
    288       if (full_matrices_) {
    289         v_shape = V->shape();
    290       } else {
    291         TensorShape shapeRaw = M.shape();
    292         shapeRaw.RemoveLastDims(2);
    293         v_shape = shapeRaw;
    294         v_shape.AddDim(p);
    295         v_shape.AddDim(n);
    296       }
    297       OP_REQUIRES_OK_ASYNC(
    298           context, solver->allocate_scoped_tensor(V->dtype(), v_shape, &v_copy),
    299           done);
    300     }
    302     // get the pointers to the data
    303     Scalar* input_ptr;
    304     RealScalar* outputS_ptr;
    305     Scalar* outputU_ptr = NULL;
    306     Scalar* outputV_ptr = NULL;
    307     auto input_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    308     input_ptr = input_reshaped.data();
    309     outputS_ptr = S->template flat_inner_dims<RealScalar, 2>().data();
    310     if (compute_uv_) {
    311       // Note that U and V are flipped
    312       outputU_ptr = v_copy.template flat_inner_dims<Scalar, 3>().data();
    313       outputV_ptr = U->template flat_inner_dims<Scalar, 3>().data();
    314     }
    316     // call the SVD
    317     const int64 batch_size = input_reshaped.dimension(0);
    318     std::vector<DeviceLapackInfo> dev_info;
    319     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "gesvd"));
    320     // Note that m and n are flipped
    321     RunSVD(context, done, n, m, p, batch_size, input_ptr, outputS_ptr,
    322            outputU_ptr, outputV_ptr, dev_info.back().mutable_data(),
    323            solver.get());
    325     // Transpose V
    326     if (compute_uv_) {
    327       auto device = context->eigen_device<GPUDevice>();
    328       OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, v_copy, V), done);
    329     }
    331     // now check if the SVD operation succeeded or not
    332     CheckResult(context, std::move(done), dev_info, std::move(solver));
    333   }
    335   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    336     const Tensor& input = context->input(0);
    337     const int ndims = input.dims();
    338     const int64 m = input.dim_size(ndims - 2);
    339     const int64 n = input.dim_size(ndims - 1);
    340     const int64 p = std::min(m, n);
    342     // Validate inputs.
    343     OP_REQUIRES_ASYNC(
    344         context, ndims >= 2,
    345         errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
    346         done);
    348     // output tensors.
    349     Tensor* outputU = NULL;
    350     Tensor* outputS = NULL;
    351     Tensor* outputV = NULL;
    353     // compute  shapes
    354     TensorShape shapeRaw = input.shape();
    355     shapeRaw.RemoveLastDims(2);
    356     TensorShape shapeS = shapeRaw;
    357     TensorShape shapeU = shapeRaw;
    358     TensorShape shapeV = shapeRaw;
    359     shapeS.AddDim(p);
    360     if (compute_uv_) {
    361       if (full_matrices_) {
    362         shapeU.AddDim(m);
    363         shapeU.AddDim(m);
    364         shapeV.AddDim(n);
    365         shapeV.AddDim(n);
    366       } else {
    367         shapeU.AddDim(m);
    368         shapeU.AddDim(p);
    369         shapeV.AddDim(n);
    370         shapeV.AddDim(p);
    371       }
    372     } else {
    373       shapeU = TensorShape({0});
    374       shapeV = TensorShape({0});
    375     }
    377     // allocate output
    378     OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shapeS, &outputS),
    379                          done);
    380     OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, shapeU, &outputU),
    381                          done);
    382     OP_REQUIRES_OK_ASYNC(context, context->allocate_output(2, shapeV, &outputV),
    383                          done);
    385     if (n == 0 || m == 0) {
    386       // If X is an empty matrix (0 rows, 0 col), X * X' == X.
    387       // Therefore, we return X.
    388       done();
    389       return;
    390     }
    392     // call implementations
    393     if (m >= n) {
    394       PerformSVD_MgeqN(context, done, m, n, p, input, outputS, outputU,
    395                        outputV);
    396     } else {
    397       PerformSVD_MlessN(context, done, m, n, p, input, outputS, outputU,
    398                         outputV);
    399     }
    400   }
    402  private:
    403   bool compute_uv_;
    404   bool full_matrices_;
    405 };
// Register the GPU kernel for the "Svd" op for the real floating point types.
// TODO: add support for complex types
REGISTER_LINALG_OP_GPU("Svd", (SvdOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Svd", (SvdOpGpu<double>), double);

// Deprecated kernels: the same implementation registered under the
// "BatchSvd" op name.
REGISTER_LINALG_OP_GPU("BatchSvd", (SvdOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("BatchSvd", (SvdOpGpu<double>), double);
    415 }  // namespace tensorflow
    417 #endif  // GOOGLE_CUDA