Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     19 #include "tensorflow/core/framework/tensor_shape.h"
     20 
     21 namespace tensorflow {
     22 namespace {
     23 
     24 class MatrixBandPartOp : public XlaOpKernel {
     25  public:
     26   explicit MatrixBandPartOp(OpKernelConstruction* context)
     27       : XlaOpKernel(context) {}
     28 
     29   void Compile(XlaOpKernelContext* context) override {
     30     const TensorShape input_shape = context->InputShape(0);
     31     // Preliminary validation of sizes.
     32     OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
     33                 errors::InvalidArgument(
     34                     "input must be at least 2-dim, received shape: ",
     35                     input_shape.DebugString()));
     36 
     37     const TensorShape num_lower_in_shape = context->InputShape(1);
     38     OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape),
     39                 errors::InvalidArgument("num_lower must be scalar, got shape ",
     40                                         num_lower_in_shape.DebugString()));
     41 
     42     const TensorShape num_upper_in_shape = context->InputShape(2);
     43     OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape),
     44                 errors::InvalidArgument("num_upper must be scalar, got shape ",
     45                                         num_upper_in_shape.DebugString()));
     46 
     47     xla::ComputationBuilder* builder = context->builder();
     48     xla::ComputationDataHandle input = context->Input(0);
     49     xla::ComputationDataHandle num_lower = context->Input(1);
     50     xla::ComputationDataHandle num_upper = context->Input(2);
     51     DataType input_type = context->input_type(0);
     52     DataType index_type = context->input_type(1);
     53 
     54     TensorShape batch_shape = input_shape;
     55     batch_shape.RemoveLastDims(2);
     56     const int64 m = input_shape.dim_size(input_shape.dims() - 2);
     57     const int64 n = input_shape.dim_size(input_shape.dims() - 1);
     58 
     59     // Compute 'offset', which is how many diagonals we are above/below the
     60     // diagonal.
     61     xla::ComputationDataHandle iota_m;
     62     OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
     63 
     64     xla::ComputationDataHandle iota_n;
     65     OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
     66 
     67     auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
     68                                /*broadcast_dimensions=*/{0});
     69 
     70     // If num_lower or num_upper are negative, include all lower/upper
     71     // diagonals.
     72     auto zero_index = XlaHelpers::Zero(builder, index_type);
     73     num_lower = builder->Select(
     74         builder->Lt(num_lower, zero_index),
     75         XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
     76     num_upper = builder->Select(
     77         builder->Lt(num_upper, zero_index),
     78         XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);
     79 
     80     auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
     81                                   builder->Le(offset, num_upper));
     82     indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
     83 
     84     auto zero_input = XlaHelpers::Zero(builder, input_type);
     85     auto output = builder->Select(
     86         indicator, input,
     87         builder->Broadcast(zero_input, input_shape.dim_sizes()));
     88 
     89     context->SetOutput(0, output);
     90   }
     91 
     92  private:
     93   TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
     94 };
     95 REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp);
     96 
     97 }  // namespace
     98 }  // namespace tensorflow
     99