Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific Concat Ops.
     17 
     18 #include <limits>
     19 #include <vector>
     20 
     21 #include "tensorflow/compiler/tf2xla/type_util.h"
     22 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     25 #include "tensorflow/compiler/xla/literal_util.h"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/framework/register_types.h"
     28 #include "tensorflow/core/framework/tensor.h"
     29 #include "tensorflow/core/framework/tensor_types.h"
     30 #include "tensorflow/core/framework/types.h"
     31 #include "tensorflow/core/kernels/bounds_check.h"
     32 #include "tensorflow/core/kernels/concat_lib.h"
     33 #include "tensorflow/core/lib/core/status.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace tensorflow {
     37 namespace {
     38 
     39 // --------------------------------------------------------------------------
     40 class ConcatBaseOp : public XlaOpKernel {
     41  public:
     42   ConcatBaseOp(OpKernelConstruction* c, int axis_index)
     43       : XlaOpKernel(c), axis_index_(axis_index) {}
     44 
     45   void Compile(XlaOpKernelContext* ctx) override {
     46     const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_);
     47     OP_REQUIRES(
     48         ctx, IsLegacyScalar(concat_dim_tensor_shape),
     49         errors::InvalidArgument(
     50             "Concat dim tensor should be a scalar integer, but got shape ",
     51             concat_dim_tensor_shape.DebugString()));
     52     xla::Literal literal;
     53     OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal));
     54     // TODO(annarev): add a helper to support int64 input.
     55     const int32 concat_dim = literal.Get<int>({});
     56 
     57     std::vector<xla::ComputationDataHandle> values;
     58     std::vector<TensorShape> shapes;
     59     OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
     60     const int N = values.size();
     61     const int input_dims = shapes[0].dims();
     62     const TensorShape& input_shape = shapes[0];
     63 
     64     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
     65     OP_REQUIRES(ctx,
     66                 (0 <= axis && axis < input_dims) ||
     67                     (allow_legacy_scalars() && concat_dim == 0),
     68                 errors::InvalidArgument(
     69                     "ConcatOp : Expected concatenating dimensions in the range "
     70                     "[",
     71                     -input_dims, ", ", input_dims, "), but got ", concat_dim));
     72 
     73     // Make a vector holding the ComputationDataHandles for each of
     74     // the inputs that has non-zero elements.
     75     std::vector<xla::ComputationDataHandle> input_data;
     76     int output_concat_dim = 0;
     77     const bool input_is_scalar = IsLegacyScalar(input_shape);
     78     for (int i = 0; i < N; ++i) {
     79       xla::ComputationDataHandle handle = values[i];
     80       const TensorShape& in_shape = shapes[i];
     81       const bool in_is_scalar = IsLegacyScalar(in_shape);
     82       OP_REQUIRES(
     83           ctx,
     84           in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
     85           errors::InvalidArgument(
     86               "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
     87               input_shape.DebugString(), " vs. shape[", i,
     88               "] = ", in_shape.DebugString()));
     89       if (in_shape.dims() == 0) {
     90         // Inputs that come in as scalars must be reshaped to 1-vectors.
     91         input_data.push_back(ctx->builder()->Reshape(handle, {1}));
     92       } else {
     93         input_data.push_back(handle);
     94       }
     95       output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
     96     }
     97 
     98     VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
     99     ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis));
    100   }
    101 
    102  private:
    103   int axis_index_;
    104 };
    105 
    106 class ConcatOp : public ConcatBaseOp {
    107  public:
    108   explicit ConcatOp(OpKernelConstruction* c)
    109       : ConcatBaseOp(c, /* axis_index */ 0) {}
    110 };
    111 
    112 // ConcatV2 operation is the same as Concat except 'concat_dim'
    113 // is the last input instead of the first and renamed to 'axis'.
    114 class ConcatV2Op : public ConcatBaseOp {
    115  public:
    116   explicit ConcatV2Op(OpKernelConstruction* c)
    117       : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {}
    118 };
    119 
    120 REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp);
    121 REGISTER_XLA_OP(Name("ConcatV2")
    122                     .TypeConstraint("Tidx", DT_INT32)
    123                     .CompileTimeConstInput("axis"),
    124                 ConcatV2Op);
    125 
    126 class ConcatOffsetOp : public XlaOpKernel {
    127  public:
    128   explicit ConcatOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
    129 
    130   void Compile(XlaOpKernelContext* ctx) override {
    131     const TensorShape concat_dim_shape = ctx->InputShape(0);
    132     OP_REQUIRES(
    133         ctx, IsLegacyScalar(concat_dim_shape),
    134         errors::InvalidArgument(
    135             "Concat dim tensor should be a scalar integer, but got shape ",
    136             concat_dim_shape.DebugString()));
    137     for (int i = 1; i < ctx->num_inputs(); ++i) {
    138       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)),
    139                   errors::InvalidArgument("input ", i,
    140                                           " should be a vector, but got shape ",
    141                                           ctx->InputShape(i).DebugString()));
    142     }
    143     // Suppose a Concat() op needs to Concatenate N tensors, each of
    144     // which has the same number of dimensions.  Their shapes match
    145     // except the concat dimension.
    146     //
    147     // E.g., say, we want to concatenate 3 tensors in the 2nd
    148     // dimension, and their shapes are:
    149     //
    150     //  [2, 2, 5, 7]
    151     //  [2, 3, 5, 7]
    152     //  [2, 4, 5, 7]
    153     //
    154     // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
    155     // [2,9,5,7]. We will compute the cumulative sum along the 2nd
    156     // dimension to figure out each input's offset in the concatenated
    157     // output:
    158     //  [0, 0, 0, 0]
    159     //  [0, 2, 0, 0]
    160     //  [0, 5, 0, 0]
    161     const int32 N = ctx->num_inputs() - 1;
    162     const TensorShape inp0_shape = ctx->InputShape(1);
    163     xla::Literal inp0_literal;
    164     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal));
    165     const int64 dims = inp0_shape.num_elements();
    166 
    167     xla::Literal concat_dim_literal;
    168     OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal));
    169     const int64 cdim = concat_dim_literal.Get<int>({});
    170 
    171     VLOG(1) << "ConcatOffset " << cdim << "," << dims;
    172     int32 axis = cdim < 0 ? cdim + dims : cdim;
    173     OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
    174                 errors::InvalidArgument("Concat dim is out of range: ", axis,
    175                                         " vs. ", dims));
    176     int32 offset = 0;
    177     for (int i = 0; i < N; ++i) {
    178       const TensorShape inp_shape = ctx->InputShape(1 + i);
    179       OP_REQUIRES(ctx, dims == inp_shape.num_elements(),
    180                   errors::InvalidArgument("input ", i, " should contain ", dims,
    181                                           " elements, but got ",
    182                                           inp_shape.num_elements()));
    183       xla::Literal inp_literal;
    184       OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal));
    185 
    186       Tensor out_constant(DT_INT32, TensorShape({dims}));
    187       auto out_vec = out_constant.vec<int32>();
    188       for (int64 j = 0; j < dims; ++j) {
    189         if (j == axis) {
    190           out_vec(j) = offset;
    191           offset += inp_literal.Get<int>({j});
    192         } else {
    193           const int32 inp0_element = inp0_literal.Get<int>({j});
    194           const int32 inp_element = inp_literal.Get<int>({j});
    195           OP_REQUIRES(ctx, (inp0_element == inp_element),
    196                       errors::InvalidArgument("input[", i, ",", j,
    197                                               "] mismatch: ", inp0_element,
    198                                               " vs. ", inp_element));
    199           out_vec(j) = 0;
    200         }
    201       }
    202 
    203       ctx->SetConstantOutput(i, out_constant);
    204     }
    205   }
    206 };
    207 
    208 REGISTER_XLA_OP(Name("ConcatOffset")
    209                     .CompileTimeConstInput("concat_dim")
    210                     .CompileTimeConstInput("shape"),
    211                 ConcatOffsetOp);
    212 
    213 }  // namespace
    214 }  // namespace tensorflow
    215