/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {

// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
xla::StatusOr<xla::ComputationDataHandle> CreateDiagonal(
    const xla::ComputationDataHandle& input, int64 last_dim_size,
    tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx,
    xla::ComputationBuilder* builder) {
  // Create two matrices that have the following forms, and compare them:
  //
  // [[0, 0, 0, 0]            [[0, 1, 2, 3]
  //  [1, 1, 1, 1]             [0, 1, 2, 3]
  //  [2, 2, 2, 2]             [0, 1, 2, 3]
  //  [3, 3, 3, 3]]            [0, 1, 2, 3]]
  //
  // This produces a predicate matrix of the right size, with "true" on the
  // diagonal.
  xla::ComputationDataHandle iota;
  TF_RETURN_IF_ERROR(
      XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
  xla::ComputationDataHandle iota_broadcast =
      builder->Broadcast(iota, {last_dim_size});
  xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0});
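  // The {0} broadcast dimension maps 'iota' onto dimension 0 of
  // 'iota_broadcast', so element (i, j) of 'mask' computes (j == i) and is
  // true exactly on the diagonal.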

  // If this is a batched diagonal, broadcast the mask across the other
  // dimensions.
  if (!other_dims.empty()) {
    mask = builder->Broadcast(mask, other_dims);
  }

  // Broadcast the input, and then use the mask computed above to select the
  // diagonal:
  // e.g., in 2D:
  //         [[t, f, f]    [[1, 1, 1]    [[0, 0, 0]      [[1, 0, 0]
  // select(  [f, t, f]  ,  [4, 4, 4]  ,  [0, 0, 0]  ) =  [0, 4, 0]
  //          [f, f, t]]    [9, 9, 9]]    [0, 0, 0]]      [0, 0, 9]]
  //
  // Broadcasting the input is less than trivial, since we need to broadcast
  // into a "middle" dimension. We can do this with a reshape + implicit
  // broadcast.
  // TODO(b/30112114): Replace with in-dim broadcast when those are supported.
  std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end());
  broadcast_dims.push_back(1LL);
  broadcast_dims.push_back(last_dim_size);
  xla::ComputationDataHandle input_broadcast =
      builder->Reshape(input, broadcast_dims);

  broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
  xla::PrimitiveType element_type;
  TF_RETURN_IF_ERROR(
      DataTypeToPrimitiveType(ctx->input_type(0), &element_type));
  auto broadcast_shape =
      xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
  xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape);

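  // Adding a zero tensor of the full broadcast shape implicitly broadcasts
  // the size-1 "middle" dimension of 'input_broadcast' up to 'last_dim_size'.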
  input_broadcast = builder->Add(input_broadcast, zeros);
  return builder->Select(mask, input_broadcast, zeros);
}

class DiagOp : public XlaOpKernel {
 public:
  explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* builder = ctx->builder();

    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
                errors::InvalidArgument(
                    "Diag op must have at least one input"));
    const TensorShape input_shape = ctx->InputShape(0);

    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    xla::ComputationDataHandle input = ctx->Input(0);

    // Picture:
    // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0]
    //                            [0, 2, 0, 0]
    //                            [0, 0, 3, 0]
    //                            [0, 0, 0, 4]]

    // Flattens the input to 1D.
    int64 size = input_shape.num_elements();
    input = builder->Reshape(input, {size});
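    // Flattening is sound because output[i1, ..., ik, j1, ..., jk] of
    // tf.diag is nonzero only when the two index tuples coincide, which is
    // exactly when their flattened offsets coincide.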

    // Create an R2 with the R1 diagonal.
    auto diag_or_status =
        CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder);
    OP_REQUIRES_OK(ctx, diag_or_status.status());
    xla::ComputationDataHandle diag = diag_or_status.ValueOrDie();

    // Reshapes to the final shape.
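    // The output doubles the input rank: a [3] input yields a [3, 3] output,
    // and a [2, 3] input yields a [2, 3, 2, 3] output.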
    std::vector<int64> new_dims(dims.size() * 2);
    std::copy(dims.begin(), dims.end(), new_dims.begin());
    std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size());
    diag = builder->Reshape(diag, new_dims);

    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("Diag"), DiagOp);

class DiagPartOp : public XlaOpKernel {
 public:
  explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* builder = ctx->builder();

    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();

    int num_dims = dims.size();
    const int out_dims = num_dims / 2;

    OP_REQUIRES(ctx, 2 <= num_dims,
                errors::InvalidArgument("Expected 2 <= dims, got shape ",
                                        input_shape.DebugString()));
    OP_REQUIRES(ctx, num_dims % 2 == 0,
                errors::InvalidArgument("The input tensor must have even rank; "
                                        "got shape ",
                                        input_shape.DebugString()));
    int64 new_size = 1;
    std::vector<int64> new_dims;
    for (int i = 0; i < out_dims; i++) {
      OP_REQUIRES(
          ctx, dims[i] == dims[i + out_dims],
          errors::InvalidArgument("Invalid shape ", input_shape.DebugString(),
                                  ": dimensions ", i, " and ", i + out_dims,
                                  " do not match."));
      new_size *= dims[i];
      new_dims.push_back(dims[i]);
    }

    xla::ComputationDataHandle diag = ctx->Input(0);

    // TODO(b/30878775): use Slice with strides when supported, in place of
    // the Pad -> Reshape -> Slice.

    // Picture:
    // [[1, 0, 0, 0]  pad and reshape to [[1, 0, 0, 0, 0],
    //  [0, 2, 0, 0]  =================>  [2, 0, 0, 0, 0],
    //  [0, 0, 3, 0]                      [3, 0, 0, 0, 0],
    //  [0, 0, 0, 4]]                     [4, 0, 0, 0, 0]]
    // and then slice out the first column.
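    // This works because consecutive diagonal elements of an n x n matrix
    // are exactly n + 1 apart in the flattened layout; after padding to
    // n * (n + 1) elements and reshaping to [n, n + 1], the diagonal
    // occupies column 0.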

    // Flattens the input to 1D.
    int64 size = input_shape.num_elements();
    diag = builder->Reshape(diag, {size});

    // Adds 'new_size' elements of zero padding after the last element.
    xla::PaddingConfig config;
    auto* dim = config.add_dimensions();
    dim->set_edge_padding_high(new_size);
    auto zero = XlaHelpers::Zero(builder, input_type(0));
    diag = builder->Pad(diag, zero, config);

    // Reshapes so the diagonal is now in the first column.
    diag = builder->Reshape(diag, {new_size, new_size + 1});

    // Slices out the first column and reshapes to the final shape.
    diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
    diag = builder->Reshape(diag, new_dims);

    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp);

class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* builder = ctx->builder();

    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
                errors::InvalidArgument(
                    "MatrixDiag op must have at least one input"));
    const TensorShape input_shape = ctx->InputShape(0);

    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    xla::ComputationDataHandle diag = ctx->Input(0);

    int last_dim = dims.size() - 1;
    int64 last_dim_size = input_shape.dim_size(last_dim);
    tensorflow::gtl::ArraySlice<int64> other_dims(dims);
    other_dims.pop_back();
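    // The last dimension holds the diagonal values and any leading
    // dimensions are batch dimensions: e.g. a [2, 4] input produces a
    // [2, 4, 4] output.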

    auto diag_or_status =
        CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder);
    OP_REQUIRES_OK(ctx, diag_or_status.status());
    diag = diag_or_status.ValueOrDie();
    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);

class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* builder = ctx->builder();

    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();

    OP_REQUIRES(ctx, 2 <= dims.size(),
                errors::InvalidArgument("Expected 2 <= dims, got shape ",
                                        input_shape.DebugString()));

    xla::ComputationDataHandle diag = ctx->Input(0);

    int last_dim = dims.size() - 1;
    int64 last_dim_size = dims[last_dim];

    // The smaller of the last two dimension sizes.
    int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]);

    // TODO(b/30878775): use Slice with strides when supported, in place of
    // the Pad -> Reshape -> Slice.

    // Picture: for each 2D matrix in the tensor's last two dimensions:
    // [[1, 0, 0, 0]  pad and reshape to [[1, 0, 0, 0, 0],
    //  [0, 2, 0, 0]  =================>  [2, 0, 0, 0, 0],
    //  [0, 0, 3, 0]]                     [3, 0, 0, 0, 0]]
    // and then slice out the first column.
    //
    // Another example, with a tall and narrow input:
    // [[1, 0]  pad and reshape to [[1, 0, 0],
    //  [0, 2]  =================>  [2, 0, 0]]
    //  [0, 0]
    //  [0, 0]]

    // Collapses the last two dimensions.
    std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1);
    flattened_dims.back() *= dims.back();
    diag = builder->Reshape(diag, flattened_dims);

    // Slices or pads the last dimension to 'target_size'.
    int64 actual_size = flattened_dims.back();
    int64 target_size = smaller_dim_size * (last_dim_size + 1);
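    // 'target_size' leaves room for 'smaller_dim_size' diagonal entries,
    // each last_dim_size + 1 apart in the flattened layout, so the reshape
    // below lands the diagonal in column 0 of each batch matrix.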
    if (actual_size < target_size) {
      xla::PaddingConfig config =
          xla::MakeNoPaddingConfig(flattened_dims.size());
      auto* dim = config.mutable_dimensions(flattened_dims.size() - 1);
      dim->set_edge_padding_high(target_size - actual_size);
      auto zero = XlaHelpers::Zero(builder, input_type(0));
      diag = builder->Pad(diag, zero, config);
    } else if (actual_size > target_size) {
      std::vector<int64> start(flattened_dims.size(), 0);
      std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
      std::vector<int64> strides(flattened_dims.size(), 1);
      limits[flattened_dims.size() - 1] = target_size;
      diag = builder->Slice(diag, start, limits, strides);
    }

    // Reshapes so the target values are in the first position of the last
    // dimension.
    dims[last_dim - 1] = smaller_dim_size;
    dims[last_dim] = last_dim_size + 1;
    diag = builder->Reshape(diag, dims);

    // Slices out the first column and reshapes to the final shape.
    std::vector<int64> start(dims.size(), 0);
    std::vector<int64> limits(dims.begin(), dims.end());
    std::vector<int64> strides(dims.size(), 1);
    limits[last_dim] = 1;
    diag = builder->Slice(diag, start, limits, strides);

    // Collapses away the last dimension.
    dims.pop_back();
    diag = builder->Reshape(diag, dims);

    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);

}  // namespace
}  // namespace tensorflow