1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/linalg_ops_common.h" 17 18 #include <utility> 19 20 #include "third_party/eigen3/Eigen/Core" 21 #include "tensorflow/core/framework/device_base.h" 22 #include "tensorflow/core/framework/kernel_def_builder.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/tensor_shape.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 31 // static 32 template <typename Scalar> 33 void LinearAlgebraOp<Scalar>::ValidateSingleMatrix( 34 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { 35 OP_REQUIRES(context, input_matrix_shapes.size() == 1, 36 errors::InvalidArgument("Expected a single input matrix, got %d.", 37 input_matrix_shapes.size())); 38 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]), 39 errors::InvalidArgument("Input must be a matrix.")); 40 } 41 42 // static 43 template <typename Scalar> 44 void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix( 45 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { 46 OP_REQUIRES(context, input_matrix_shapes.size() == 1, 47 errors::InvalidArgument("Expected a single input matrix, got %d.", 48 input_matrix_shapes.size())); 49 OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]), 50 errors::InvalidArgument("Input matrix must be square.")); 51 } 52 53 // static 54 template <typename Scalar> 55 void LinearAlgebraOp<Scalar>::ValidateSolver( 56 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { 57 OP_REQUIRES(context, input_matrix_shapes.size() == 2, 58 errors::InvalidArgument("Expected two input matrices, got %d.", 59 input_matrix_shapes.size())); 60 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]), 61 errors::InvalidArgument("First input (lhs) must be a matrix.")); 62 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]), 63 errors::InvalidArgument("Second input (rhs) must be a matrix.")); 64 OP_REQUIRES( 65 context, 66 input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0), 67 errors::InvalidArgument("Input matrix and rhs are incompatible.")); 68 } 69 70 // static 71 template <typename Scalar> 72 void LinearAlgebraOp<Scalar>::ValidateSquareSolver( 73 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { 74 OP_REQUIRES(context, input_matrix_shapes.size() == 2, 75 errors::InvalidArgument("Expected two input matrices, got %d.", 76 input_matrix_shapes.size())); 77 OP_REQUIRES( 78 context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]), 79 errors::InvalidArgument("First input (lhs) must be a square matrix.")); 80 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]), 81 errors::InvalidArgument("Second input (rhs) must be a matrix.")); 82 OP_REQUIRES( 83 context, 84 input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0), 85 errors::InvalidArgument("Input matrix and rhs are incompatible.")); 86 } 87 88 template <typename Scalar> 89 void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) { 90 TensorInputs inputs; 91 TensorShapes input_matrix_shapes; 92 TensorShape batch_shape; 93 AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape); 94 95 TensorShapes output_matrix_shapes; 96 TensorOutputs outputs; 97 PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs, 98 &output_matrix_shapes); 99 100 // Process the individual matrix problems in parallel using a threadpool. 101 auto shard = [this, &inputs, &input_matrix_shapes, &outputs, 102 &output_matrix_shapes, context](int64 begin, int64 end) { 103 for (int64 i = begin; i < end; ++i) { 104 ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs, 105 output_matrix_shapes); 106 } 107 }; 108 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 109 Shard(worker_threads.num_threads, worker_threads.workers, 110 batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard); 111 } 112 113 template <typename Scalar> 114 void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context, 115 TensorInputs* inputs, 116 TensorShapes* input_matrix_shapes, 117 TensorShape* batch_shape) { 118 int input_rank = -1; 119 for (int i = 0; i < NumMatrixInputs(context); ++i) { 120 const Tensor& in = context->input(i); 121 if (i == 0) { 122 input_rank = in.dims(); 123 OP_REQUIRES( 124 context, input_rank >= 2, 125 errors::InvalidArgument("Input tensor ", i, 126 " must have rank >= 2, got ", input_rank)); 127 // If the tensor rank is greater than 2, we consider the inner-most 128 // dimensions as matrices, and loop over all the other outer ("batch") 129 // dimensions to compute the results. 130 for (int dim = 0; dim < input_rank - 2; ++dim) { 131 batch_shape->AddDim(in.dim_size(dim)); 132 } 133 } else { 134 // Make sure that all inputs have the same rank and outer dimensions. 135 OP_REQUIRES(context, input_rank == in.dims(), 136 errors::InvalidArgument( 137 "All input tensors must have the same rank.")); 138 for (int dim = 0; dim < input_rank - 2; ++dim) { 139 OP_REQUIRES( 140 context, in.dim_size(dim) == batch_shape->dim_size(dim), 141 errors::InvalidArgument( 142 "All input tensors must have the same outer dimensions.")); 143 } 144 } 145 146 const int row_dimension = input_rank - 2; 147 const int col_dimension = input_rank - 1; 148 const int64 num_rows = in.dim_size(row_dimension); 149 const int64 num_cols = in.dim_size(col_dimension); 150 input_matrix_shapes->emplace_back( 151 std::initializer_list<int64>({num_rows, num_cols})); 152 inputs->emplace_back(&in); 153 } 154 // Have the derived class validate that the inputs are as expected. 155 ValidateInputMatrixShapes(context, *input_matrix_shapes); 156 } 157 158 template <typename Scalar> 159 void LinearAlgebraOp<Scalar>::PrepareOutputs( 160 OpKernelContext* context, const TensorShapes& input_matrix_shapes, 161 const TensorShape& batch_shape, TensorOutputs* outputs, 162 TensorShapes* output_matrix_shapes) { 163 // Get shape for each of the matrix outputs produced by the derived class. 164 *output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes); 165 const int num_outputs = output_matrix_shapes->size(); 166 167 // Make sure the number of op outputs is what the derived class expects. 168 OP_REQUIRES( 169 context, num_outputs <= context->num_outputs(), 170 errors::Internal( 171 "Derived class expected more outputs (%d) that the op has (%d).", 172 num_outputs, context->num_outputs())); 173 174 // Allocate outputs. 175 std::set<int> unused_inputs; 176 for (int input_idx = 0; input_idx < context->num_inputs(); ++input_idx) { 177 unused_inputs.insert(input_idx); 178 } 179 for (int output_idx = 0; output_idx < context->num_outputs(); ++output_idx) { 180 TensorShape output_tensor_shape({}); 181 if (output_idx < num_outputs) { 182 // This output is used, set up output shape and allocate it. 183 const TensorShape& output_matrix_shape = 184 output_matrix_shapes->at(output_idx); 185 OP_REQUIRES(context, output_matrix_shape.dims() <= 2, 186 errors::InvalidArgument( 187 "Rank of matrix output no. %d must be 0, 1 or 2, got %d.", 188 output_idx, output_matrix_shape.dims())); 189 190 // The final output has the shape of the outer batch dimensions 191 // concatenated with the output_matrix_shape (if the output is not 192 // scalar). 193 output_tensor_shape = batch_shape; 194 output_tensor_shape.AppendShape(output_matrix_shape); 195 } 196 Tensor* out = nullptr; 197 // See if there is an input buffer we can reuse for this output. 198 bool reused_input = false; 199 if (EnableInputForwarding()) { 200 for (int input_idx : unused_inputs) { 201 if (context->forward_input_to_output_with_shape( 202 input_idx, output_idx, output_tensor_shape, &out)) { 203 reused_input = true; 204 unused_inputs.erase(input_idx); 205 break; 206 } 207 } 208 } 209 if (!reused_input) { 210 OP_REQUIRES_OK(context, context->allocate_output( 211 output_idx, output_tensor_shape, &out)); 212 } 213 outputs->emplace_back(out); 214 } 215 } 216 217 template <typename Scalar> 218 void LinearAlgebraOp<Scalar>::ComputeTensorSlice( 219 OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs, 220 const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs, 221 const TensorShapes& output_matrix_shapes) { 222 ConstMatrixMaps matrix_inputs; 223 for (size_t i = 0; i < inputs.size(); ++i) { 224 // TODO(kalakris): Handle alignment if possible. Eigen::Map is 225 // unaligned by default. 226 matrix_inputs.emplace_back( 227 inputs[i]->flat<Scalar>().data() + 228 matrix_index * input_matrix_shapes[i].num_elements(), 229 input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1)); 230 } 231 232 MatrixMaps matrix_outputs; 233 for (size_t i = 0; i < output_matrix_shapes.size(); ++i) { 234 // The output matrix shape may not be a matrix. 235 int num_output_rows = output_matrix_shapes[i].dims() >= 1 236 ? output_matrix_shapes[i].dim_size(0) 237 : 1; 238 int num_output_cols = output_matrix_shapes[i].dims() == 2 239 ? output_matrix_shapes[i].dim_size(1) 240 : 1; 241 matrix_outputs.emplace_back( 242 outputs[i]->flat<Scalar>().data() + 243 matrix_index * output_matrix_shapes[i].num_elements(), 244 num_output_rows, num_output_cols); 245 } 246 ComputeMatrix(context, matrix_inputs, &matrix_outputs); 247 } 248 249 // Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use. 250 template class LinearAlgebraOp<float>; 251 template class LinearAlgebraOp<double>; 252 template class LinearAlgebraOp<complex64>; 253 template class LinearAlgebraOp<complex128>; 254 255 } // namespace tensorflow 256