1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_ 17 #define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_ 18 19 // Classes to support linear algebra functionality, similar to the numpy.linalg 20 // module. Supports batch computation on several matrices at once, sharding the 21 // computations across different threads if necessary. 22 #include <algorithm> 23 24 #include "third_party/eigen3/Eigen/Core" 25 #include "tensorflow/core/framework/kernel_def_builder.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_shape.h" 29 #include "tensorflow/core/framework/tensor_types.h" 30 #include "tensorflow/core/framework/types.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/gtl/inlined_vector.h" 33 #include "tensorflow/core/platform/types.h" 34 #include "tensorflow/core/util/work_sharder.h" 35 36 namespace tensorflow { 37 38 // Base class for linear algebra operators. 39 template <typename Scalar> 40 class LinearAlgebraOp : public OpKernel { 41 public: 42 explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {} 43 44 void Compute(OpKernelContext* context) override; 45 46 protected: 47 using TensorShapes = gtl::InlinedVector<TensorShape, 4>; 48 // Returns the number of leading inputs that are to be treated as matrix 49 // inputs. By default this is all the inputs. Derived classes can override 50 // this to tell the base class to ignore one or more trailing inputs. 51 virtual int NumMatrixInputs(const OpKernelContext* context) const { 52 return context->num_inputs(); 53 } 54 55 // Returns true if the number of inputs and their shapes are as expected. 56 // Many ops take a single square input matrix, so we provide that as a default 57 // implementation for convenience. 58 virtual void ValidateInputMatrixShapes( 59 OpKernelContext* context, const TensorShapes& input_matrix_shapes) const { 60 ValidateSingleSquareMatrix(context, input_matrix_shapes); 61 } 62 63 // Convenience validators for common cases: 64 // 65 // Validate op taking a single matrix A. 66 static void ValidateSingleMatrix(OpKernelContext* context, 67 const TensorShapes& input_matrix_shapes); 68 // Validate op taking a single square matrix A. 69 static void ValidateSingleSquareMatrix( 70 OpKernelContext* context, const TensorShapes& input_matrix_shapes); 71 // Validate op taking two matrices A and B that have the same number of rows. 72 static void ValidateSolver(OpKernelContext* context, 73 const TensorShapes& input_matrix_shapes); 74 // Validate op taking two matrices A and B that have the same number of rows 75 // and A is square. 76 static void ValidateSquareSolver(OpKernelContext* context, 77 const TensorShapes& input_matrix_shapes); 78 79 // Returns the output shapes of each individual matrix operation. Output 80 // matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0. 81 // 82 // The derived class may return a number of shapes (N) less than 83 // context->num_outputs() (M) to indicate that a only leading subset of 84 // the outputs will be populated. In this case, a dummy scalar tensor with 85 // value zero will be return for the last M-N outputs. 86 // 87 // For many ops, the output dimensions are the same as the input dimensions, 88 // so we provide that as a default implementation for convenience. 89 virtual TensorShapes GetOutputMatrixShapes( 90 const TensorShapes& input_matrix_shapes) const { 91 return input_matrix_shapes; 92 } 93 94 // Returns the cost per matrix operation. This is used to determine the 95 // number of threads to use for parallelizing calls to ComputeMatrix in 96 // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments 97 // in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n) 98 // * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a 99 // default implementation for convenience. 100 virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const { 101 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); 102 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); 103 double cost = std::max(m, n) * std::min(m, n) * std::min(m, n); 104 return cost >= static_cast<double>(kint64max) ? kint64max 105 : static_cast<int64>(cost); 106 } 107 108 // Returns true if it is safe to forward (alias) input to output buffer 109 // and expect the kernel to perform the computation inplace. 110 virtual bool EnableInputForwarding() const { return true; } 111 112 using Matrix = 113 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; 114 using ConstMatrixMap = Eigen::Map<const Matrix>; 115 using MatrixMap = Eigen::Map<Matrix>; 116 using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>; 117 using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>; 118 using RealScalar = typename Eigen::NumTraits<Scalar>::Real; 119 120 // Performs a single matrix computation given input matrices, and 121 // stores the result in outputs. For batch operations, this will be called 122 // repeatedly for a single call to Compute() when multiple matrices exist in 123 // input Tensors with rank > 2. In this case the calls to ComputeMatrix are 124 // parallelized. The number of threads used is determined by a cost model from 125 // the value returned by GetCostPerUnit(). 126 virtual void ComputeMatrix(OpKernelContext* context, 127 const ConstMatrixMaps& inputs, 128 MatrixMaps* outputs) = 0; 129 130 private: 131 using TensorInputs = gtl::InlinedVector<const Tensor*, 4>; 132 using TensorOutputs = gtl::InlinedVector<Tensor*, 4>; 133 // This function maps 2-d slices (matrices) of the input and output tensors 134 // using Eigen::Map and calls ComputeMatrix implemented in terms of the 135 // Eigen::MatrixBase API by the derived class. 136 // 137 // The 'matrix_index' parameter specifies the index of the matrix to be used 138 // from each input tensor, and the index of the matrix to be written to each 139 // output tensor. The input matrices are in row major order, and located at 140 // the memory addresses 141 // inputs[i].flat<Scalar>().data() + 142 // matrix_index * input_matrix_shapes[i].num_elements() 143 // for i in 0...inputs.size()-1. 144 // The output matrices are in row major order, and located at the memory 145 // address 146 // outputs[i]->flat<Scalar>().data() + 147 // matrix_index * output_matrix_shapes[i].num_elements(). 148 // for i in 0...outputs.size()-1. 149 // 150 void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index, 151 const TensorInputs& inputs, 152 const TensorShapes& input_matrix_shapes, 153 const TensorOutputs& outputs, 154 const TensorShapes& output_matrix_shapes); 155 156 void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs, 157 TensorShapes* input_matrix_shapes, 158 TensorShape* batch_shape); 159 160 void PrepareOutputs(OpKernelContext* context, 161 const TensorShapes& input_matrix_shapes, 162 const TensorShape& batch_shape, TensorOutputs* outputs, 163 TensorShapes* output_matrix_shapes); 164 }; 165 166 // Declare LinearAlgebraOp, which is explicitly instantiated in 167 // linalg_ops_common.cc for float, double, complex64, and complex128. 168 extern template class LinearAlgebraOp<float>; 169 extern template class LinearAlgebraOp<double>; 170 extern template class LinearAlgebraOp<complex64>; 171 extern template class LinearAlgebraOp<complex128>; 172 173 } // namespace tensorflow 174 175 #define INHERIT_LINALG_TYPEDEFS(Scalar) \ 176 typedef LinearAlgebraOp<Scalar> Base; \ 177 using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \ 178 using Matrix = typename Base::Matrix; \ 179 using MatrixMap = typename Base::MatrixMap; \ 180 using MatrixMaps = typename Base::MatrixMaps; \ 181 using ConstMatrixMap = typename Base::ConstMatrixMap; \ 182 using ConstMatrixMaps = typename Base::ConstMatrixMaps; \ 183 using TensorShapes = typename Base::TensorShapes; 184 185 #define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \ 186 REGISTER_KERNEL_BUILDER( \ 187 Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass) 188 189 #define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \ 190 REGISTER_KERNEL_BUILDER( \ 191 Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass) 192 193 // Deprecated, use one of the device-specific macros above. 194 #define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \ 195 REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) 196 197 #endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_ 198