Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/linalg_ops_common.h"
     17 
     18 #include <utility>
     19 
     20 #include "third_party/eigen3/Eigen/Core"
     21 #include "tensorflow/core/framework/device_base.h"
     22 #include "tensorflow/core/framework/kernel_def_builder.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/tensor_shape.h"
     25 #include "tensorflow/core/lib/core/errors.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/platform/types.h"
     28 
     29 namespace tensorflow {
     30 
     31 // static
     32 template <typename Scalar>
     33 void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
     34     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
     35   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
     36               errors::InvalidArgument("Expected a single input matrix, got %d.",
     37                                       input_matrix_shapes.size()));
     38   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
     39               errors::InvalidArgument("Input must be a matrix."));
     40 }
     41 
     42 // static
     43 template <typename Scalar>
     44 void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
     45     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
     46   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
     47               errors::InvalidArgument("Expected a single input matrix, got %d.",
     48                                       input_matrix_shapes.size()));
     49   OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
     50               errors::InvalidArgument("Input matrix must be square."));
     51 }
     52 
     53 // static
     54 template <typename Scalar>
     55 void LinearAlgebraOp<Scalar>::ValidateSolver(
     56     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
     57   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
     58               errors::InvalidArgument("Expected two input matrices, got %d.",
     59                                       input_matrix_shapes.size()));
     60   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
     61               errors::InvalidArgument("First input (lhs) must be a matrix."));
     62   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
     63               errors::InvalidArgument("Second input (rhs) must be a matrix."));
     64   OP_REQUIRES(
     65       context,
     66       input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
     67       errors::InvalidArgument("Input matrix and rhs are incompatible."));
     68 }
     69 
     70 // static
     71 template <typename Scalar>
     72 void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
     73     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
     74   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
     75               errors::InvalidArgument("Expected two input matrices, got %d.",
     76                                       input_matrix_shapes.size()));
     77   OP_REQUIRES(
     78       context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
     79       errors::InvalidArgument("First input (lhs) must be a square matrix."));
     80   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
     81               errors::InvalidArgument("Second input (rhs) must be a matrix."));
     82   OP_REQUIRES(
     83       context,
     84       input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
     85       errors::InvalidArgument("Input matrix and rhs are incompatible."));
     86 }
     87 
     88 template <typename Scalar>
     89 void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
     90   TensorInputs inputs;
     91   TensorShapes input_matrix_shapes;
     92   TensorShape batch_shape;
     93   AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape);
     94 
     95   TensorShapes output_matrix_shapes;
     96   TensorOutputs outputs;
     97   PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
     98                  &output_matrix_shapes);
     99 
    100   // Process the individual matrix problems in parallel using a threadpool.
    101   auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
    102                 &output_matrix_shapes, context](int64 begin, int64 end) {
    103     for (int64 i = begin; i < end; ++i) {
    104       ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
    105                          output_matrix_shapes);
    106     }
    107   };
    108   auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    109   Shard(worker_threads.num_threads, worker_threads.workers,
    110         batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
    111 }
    112 
    113 template <typename Scalar>
    114 void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
    115                                             TensorInputs* inputs,
    116                                             TensorShapes* input_matrix_shapes,
    117                                             TensorShape* batch_shape) {
    118   int input_rank = -1;
    119   for (int i = 0; i < NumMatrixInputs(context); ++i) {
    120     const Tensor& in = context->input(i);
    121     if (i == 0) {
    122       input_rank = in.dims();
    123       OP_REQUIRES(
    124           context, input_rank >= 2,
    125           errors::InvalidArgument("Input tensor ", i,
    126                                   " must have rank >= 2, got ", input_rank));
    127       // If the tensor rank is greater than 2, we consider the inner-most
    128       // dimensions as matrices, and loop over all the other outer ("batch")
    129       // dimensions to compute the results.
    130       for (int dim = 0; dim < input_rank - 2; ++dim) {
    131         batch_shape->AddDim(in.dim_size(dim));
    132       }
    133     } else {
    134       // Make sure that all inputs have the same rank and outer dimensions.
    135       OP_REQUIRES(context, input_rank == in.dims(),
    136                   errors::InvalidArgument(
    137                       "All input tensors must have the same rank."));
    138       for (int dim = 0; dim < input_rank - 2; ++dim) {
    139         OP_REQUIRES(
    140             context, in.dim_size(dim) == batch_shape->dim_size(dim),
    141             errors::InvalidArgument(
    142                 "All input tensors must have the same outer dimensions."));
    143       }
    144     }
    145 
    146     const int row_dimension = input_rank - 2;
    147     const int col_dimension = input_rank - 1;
    148     const int64 num_rows = in.dim_size(row_dimension);
    149     const int64 num_cols = in.dim_size(col_dimension);
    150     input_matrix_shapes->emplace_back(
    151         std::initializer_list<int64>({num_rows, num_cols}));
    152     inputs->emplace_back(&in);
    153   }
    154   // Have the derived class validate that the inputs are as expected.
    155   ValidateInputMatrixShapes(context, *input_matrix_shapes);
    156 }
    157 
    158 template <typename Scalar>
    159 void LinearAlgebraOp<Scalar>::PrepareOutputs(
    160     OpKernelContext* context, const TensorShapes& input_matrix_shapes,
    161     const TensorShape& batch_shape, TensorOutputs* outputs,
    162     TensorShapes* output_matrix_shapes) {
    163   // Get shape for each of the matrix outputs produced by the derived class.
    164   *output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes);
    165   const int num_outputs = output_matrix_shapes->size();
    166 
    167   // Make sure the number of op outputs is what the derived class expects.
    168   OP_REQUIRES(
    169       context, num_outputs <= context->num_outputs(),
    170       errors::Internal(
    171           "Derived class expected more outputs (%d) that the op has (%d).",
    172           num_outputs, context->num_outputs()));
    173 
    174   // Allocate outputs.
    175   std::set<int> unused_inputs;
    176   for (int input_idx = 0; input_idx < context->num_inputs(); ++input_idx) {
    177     unused_inputs.insert(input_idx);
    178   }
    179   for (int output_idx = 0; output_idx < context->num_outputs(); ++output_idx) {
    180     TensorShape output_tensor_shape({});
    181     if (output_idx < num_outputs) {
    182       // This output is used, set up output shape and allocate it.
    183       const TensorShape& output_matrix_shape =
    184           output_matrix_shapes->at(output_idx);
    185       OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
    186                   errors::InvalidArgument(
    187                       "Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
    188                       output_idx, output_matrix_shape.dims()));
    189 
    190       // The final output has the shape of the outer batch dimensions
    191       // concatenated with the output_matrix_shape (if the output is not
    192       // scalar).
    193       output_tensor_shape = batch_shape;
    194       output_tensor_shape.AppendShape(output_matrix_shape);
    195     }
    196     Tensor* out = nullptr;
    197     // See if there is an input buffer we can reuse for this output.
    198     bool reused_input = false;
    199     if (EnableInputForwarding()) {
    200       for (int input_idx : unused_inputs) {
    201         if (context->forward_input_to_output_with_shape(
    202                 input_idx, output_idx, output_tensor_shape, &out)) {
    203           reused_input = true;
    204           unused_inputs.erase(input_idx);
    205           break;
    206         }
    207       }
    208     }
    209     if (!reused_input) {
    210       OP_REQUIRES_OK(context, context->allocate_output(
    211                                   output_idx, output_tensor_shape, &out));
    212     }
    213     outputs->emplace_back(out);
    214   }
    215 }
    216 
    217 template <typename Scalar>
    218 void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
    219     OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
    220     const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
    221     const TensorShapes& output_matrix_shapes) {
    222   ConstMatrixMaps matrix_inputs;
    223   for (size_t i = 0; i < inputs.size(); ++i) {
    224     // TODO(kalakris): Handle alignment if possible. Eigen::Map is
    225     // unaligned by default.
    226     matrix_inputs.emplace_back(
    227         inputs[i]->flat<Scalar>().data() +
    228             matrix_index * input_matrix_shapes[i].num_elements(),
    229         input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
    230   }
    231 
    232   MatrixMaps matrix_outputs;
    233   for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
    234     // The output matrix shape may not be a matrix.
    235     int num_output_rows = output_matrix_shapes[i].dims() >= 1
    236                               ? output_matrix_shapes[i].dim_size(0)
    237                               : 1;
    238     int num_output_cols = output_matrix_shapes[i].dims() == 2
    239                               ? output_matrix_shapes[i].dim_size(1)
    240                               : 1;
    241     matrix_outputs.emplace_back(
    242         outputs[i]->flat<Scalar>().data() +
    243             matrix_index * output_matrix_shapes[i].num_elements(),
    244         num_output_rows, num_output_cols);
    245   }
    246   ComputeMatrix(context, matrix_inputs, &matrix_outputs);
    247 }
    248 
    249 // Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
    250 template class LinearAlgebraOp<float>;
    251 template class LinearAlgebraOp<double>;
    252 template class LinearAlgebraOp<complex64>;
    253 template class LinearAlgebraOp<complex128>;
    254 
    255 }  // namespace tensorflow
    256