Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
     17 #define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
     18 
     19 // Classes to support linear algebra functionality, similar to the numpy.linalg
     20 // module. Supports batch computation on several matrices at once, sharding the
     21 // computations across different threads if necessary.
     22 #include <algorithm>
     23 
     24 #include "third_party/eigen3/Eigen/Core"
     25 #include "tensorflow/core/framework/kernel_def_builder.h"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 #include "tensorflow/core/framework/tensor_shape.h"
     29 #include "tensorflow/core/framework/tensor_types.h"
     30 #include "tensorflow/core/framework/types.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     33 #include "tensorflow/core/platform/types.h"
     34 #include "tensorflow/core/util/work_sharder.h"
     35 
     36 namespace tensorflow {
     37 
     38 // Base class for linear algebra operators.
     39 template <typename Scalar>
     40 class LinearAlgebraOp : public OpKernel {
     41  public:
     42   explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
     43 
     44   void Compute(OpKernelContext* context) override;
     45 
     46  protected:
     47   using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
     48   // Returns the number of leading inputs that are to be treated as matrix
     49   // inputs. By default this is all the inputs. Derived classes can override
     50   // this to tell the base class to ignore one or more trailing inputs.
     51   virtual int NumMatrixInputs(const OpKernelContext* context) const {
     52     return context->num_inputs();
     53   }
     54 
     55   // Returns true if the number of inputs and their shapes are as expected.
     56   // Many ops take a single square input matrix, so we provide that as a default
     57   // implementation for convenience.
     58   virtual void ValidateInputMatrixShapes(
     59       OpKernelContext* context, const TensorShapes& input_matrix_shapes) const {
     60     ValidateSingleSquareMatrix(context, input_matrix_shapes);
     61   }
     62 
     63   // Convenience validators for common cases:
     64   //
     65   // Validate op taking a single matrix A.
     66   static void ValidateSingleMatrix(OpKernelContext* context,
     67                                    const TensorShapes& input_matrix_shapes);
     68   // Validate op taking a single square matrix A.
     69   static void ValidateSingleSquareMatrix(
     70       OpKernelContext* context, const TensorShapes& input_matrix_shapes);
     71   // Validate op taking two matrices A and B that have the same number of rows.
     72   static void ValidateSolver(OpKernelContext* context,
     73                              const TensorShapes& input_matrix_shapes);
     74   // Validate op taking two matrices A and B that have the same number of rows
     75   // and A is square.
     76   static void ValidateSquareSolver(OpKernelContext* context,
     77                                    const TensorShapes& input_matrix_shapes);
     78 
     79   // Returns the output shapes of each individual matrix operation. Output
     80   // matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
     81   //
     82   // The derived class may return a number of shapes (N) less than
     83   // context->num_outputs() (M) to indicate that a only leading subset of
     84   // the outputs will be populated. In this case, a dummy scalar tensor with
     85   // value zero will be return for the last M-N outputs.
     86   //
     87   // For many ops, the output dimensions are the same as the input dimensions,
     88   // so we provide that as a default implementation for convenience.
     89   virtual TensorShapes GetOutputMatrixShapes(
     90       const TensorShapes& input_matrix_shapes) const {
     91     return input_matrix_shapes;
     92   }
     93 
     94   // Returns the cost per matrix operation. This is used to determine the
     95   // number of threads to use for parallelizing calls to ComputeMatrix in
     96   // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
     97   // in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n)
     98   // * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a
     99   // default implementation for convenience.
    100   virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
    101     double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    102     double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
    103     double cost = std::max(m, n) * std::min(m, n) * std::min(m, n);
    104     return cost >= static_cast<double>(kint64max) ? kint64max
    105                                                   : static_cast<int64>(cost);
    106   }
    107 
    108   // Returns true if it is safe to forward (alias) input to output buffer
    109   // and expect the kernel to perform the computation inplace.
    110   virtual bool EnableInputForwarding() const { return true; }
    111 
    112   using Matrix =
    113       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    114   using ConstMatrixMap = Eigen::Map<const Matrix>;
    115   using MatrixMap = Eigen::Map<Matrix>;
    116   using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>;
    117   using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>;
    118   using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    119 
    120   // Performs a single matrix computation given input matrices, and
    121   // stores the result in outputs. For batch operations, this will be called
    122   // repeatedly for a single call to Compute() when multiple matrices exist in
    123   // input Tensors with rank > 2. In this case the calls to ComputeMatrix are
    124   // parallelized. The number of threads used is determined by a cost model from
    125   // the value returned by GetCostPerUnit().
    126   virtual void ComputeMatrix(OpKernelContext* context,
    127                              const ConstMatrixMaps& inputs,
    128                              MatrixMaps* outputs) = 0;
    129 
    130  private:
    131   using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
    132   using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
    133   // This function maps 2-d slices (matrices) of the input and output tensors
    134   // using Eigen::Map and calls ComputeMatrix implemented in terms of the
    135   // Eigen::MatrixBase API by the derived class.
    136   //
    137   // The 'matrix_index' parameter specifies the index of the matrix to be used
    138   // from each input tensor, and the index of the matrix to be written to each
    139   // output tensor. The input matrices are in row major order, and located at
    140   // the memory addresses
    141   //   inputs[i].flat<Scalar>().data() +
    142   //   matrix_index * input_matrix_shapes[i].num_elements()
    143   // for i in 0...inputs.size()-1.
    144   // The output matrices are in row major order, and located at the memory
    145   // address
    146   //   outputs[i]->flat<Scalar>().data() +
    147   //   matrix_index * output_matrix_shapes[i].num_elements().
    148   // for i in 0...outputs.size()-1.
    149   //
    150   void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
    151                           const TensorInputs& inputs,
    152                           const TensorShapes& input_matrix_shapes,
    153                           const TensorOutputs& outputs,
    154                           const TensorShapes& output_matrix_shapes);
    155 
    156   void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
    157                      TensorShapes* input_matrix_shapes,
    158                      TensorShape* batch_shape);
    159 
    160   void PrepareOutputs(OpKernelContext* context,
    161                       const TensorShapes& input_matrix_shapes,
    162                       const TensorShape& batch_shape, TensorOutputs* outputs,
    163                       TensorShapes* output_matrix_shapes);
    164 };
    165 
    166 // Explicit instantiation declarations for LinearAlgebraOp; the matching
    167 // definitions (float, double, complex64, complex128) are in linalg_ops_common.cc.
    168 extern template class LinearAlgebraOp<float>;
    169 extern template class LinearAlgebraOp<double>;
    170 extern template class LinearAlgebraOp<complex64>;
    171 extern template class LinearAlgebraOp<complex128>;
    172 
    173 }  // namespace tensorflow
    174 
    175 #define INHERIT_LINALG_TYPEDEFS(Scalar)                       \
    176   typedef LinearAlgebraOp<Scalar> Base;                       \
    177   using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \
    178   using Matrix = typename Base::Matrix;                       \
    179   using MatrixMap = typename Base::MatrixMap;                 \
    180   using MatrixMaps = typename Base::MatrixMaps;               \
    181   using ConstMatrixMap = typename Base::ConstMatrixMap;       \
    182   using ConstMatrixMaps = typename Base::ConstMatrixMaps;     \
    183   using TensorShapes = typename Base::TensorShapes;
    184 
    185 #define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
    186   REGISTER_KERNEL_BUILDER(                              \
    187       Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
    188 
    189 #define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
    190   REGISTER_KERNEL_BUILDER(                              \
    191       Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)
    192 
    193 // Deprecated, use one of the device-specific macros above.
    194 #define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
    195   REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
    196 
    197 #endif  // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
    198