Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific Ops for split.
     17 
     18 #include "tensorflow/compiler/tf2xla/type_util.h"
     19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/register_types.h"
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 
     31 class SplitOp : public XlaOpKernel {
     32  public:
     33   explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     34 
     35   void Compile(XlaOpKernelContext* ctx) override {
     36     const int32 num_split = num_outputs();
     37     const TensorShape index_shape = ctx->InputShape(0);
     38     const TensorShape input_shape = ctx->InputShape(1);
     39 
     40     xla::Literal literal_index;
     41     OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index));
     42 
     43     int32 split_dim_orig;
     44     if (index_shape.dims() == 0) {
     45       split_dim_orig = literal_index.Get<int>({});
     46     } else {
     47       OP_REQUIRES(
     48           ctx, index_shape.dims() == 1,
     49           errors::InvalidArgument("split_index input to Split Op must be a "
     50                                   "scalar or a vector with 1 element"));
     51       OP_REQUIRES(
     52           ctx, index_shape.dim_size(0) == 1,
     53           errors::InvalidArgument("split_index input to Split Op must be a "
     54                                   "scalar or a vector with 1 element"));
     55       split_dim_orig = literal_index.Get<int>({0});
     56     }
     57     int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
     58                                          : split_dim_orig;
     59     OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
     60                 errors::InvalidArgument("-input rank(-", input_shape.dims(),
     61                                         ") <= split_dim < input rank (",
     62                                         input_shape.dims(), "), but got ",
     63                                         split_dim_orig));
     64 
     65     OP_REQUIRES(
     66         ctx, num_split > 0,
     67         errors::InvalidArgument(
     68             "Number of ways to split should be > 0, but got ", num_split));
     69 
     70     OP_REQUIRES(
     71         ctx, input_shape.dim_size(split_dim) % num_split == 0,
     72         errors::InvalidArgument(
     73             "Number of ways to split should evenly divide the split "
     74             "dimension, but got split_dim ",
     75             split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ",
     76             "and num_split ", num_split));
     77 
     78     // All the slices are the same size: this is the size along the
     79     // split dimension.
     80     const int32 slice_size = input_shape.dim_size(split_dim) / num_split;
     81 
     82     // The vectors we will use to define the slice. The entry for the
     83     // split dimensions varies for each output.
     84     std::vector<int64> begin(input_shape.dims(), 0);
     85     std::vector<int64> limits(input_shape.dims());
     86     std::vector<int64> strides(input_shape.dims(), 1);
     87     for (int i = 0; i < input_shape.dims(); ++i) {
     88       // Initially set up the limits to be the full size of the input:
     89       // the split dimension is filled in below.
     90       int64 dim = input_shape.dim_size(i);
     91       limits[i] = dim;
     92     }
     93 
     94     auto input = ctx->Input(1);
     95 
     96     // Create each of the outputs.
     97     for (int i = 0; i < num_split; ++i) {
     98       // Slice out the ith split from the split dimension.
     99       begin[split_dim] = i * slice_size;
    100       limits[split_dim] = (i + 1) * slice_size;
    101       ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
    102     }
    103   }
    104 };
    105 
    106 REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp);
    107 
    108 class SplitVOp : public XlaOpKernel {
    109  public:
    110   explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
    111 
    112   void Compile(XlaOpKernelContext* ctx) override {
    113     const int32 num_split = num_outputs();
    114     const TensorShape index_shape = ctx->InputShape(2);
    115     xla::Literal literal_index;
    116     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index));
    117 
    118     int32 split_dim;
    119     OP_REQUIRES(ctx, index_shape.dims() == 0,
    120                 errors::InvalidArgument("split_dim input to Split Op must be a "
    121                                         "scalar"));
    122     split_dim = literal_index.Get<int>({});
    123 
    124     xla::ComputationDataHandle input = ctx->Input(0);
    125     const TensorShape input_shape = ctx->InputShape(0);
    126 
    127     OP_REQUIRES(ctx, input_shape.dims() > 0,
    128                 errors::InvalidArgument("Can't split a 0 dimensional input"));
    129 
    130     OP_REQUIRES(
    131         ctx, 0 <= split_dim && split_dim < input_shape.dims(),
    132         errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
    133                                 input_shape.dims(), "), but got ", split_dim));
    134 
    135     OP_REQUIRES(
    136         ctx, num_split > 0,
    137         errors::InvalidArgument(
    138             "Number of ways to split should be > 0, but got ", num_split));
    139 
    140     // check that sizes are correct
    141     int total_split_size = 0;
    142     int neg_one_dim = -1;
    143     std::vector<int64> split_sizes_vec(num_split, -1);
    144     const TensorShape split_size_shape = ctx->InputShape(1);
    145     OP_REQUIRES(ctx,
    146                 split_size_shape.dims() == 1 &&
    147                     split_size_shape.num_elements() == num_split,
    148                 errors::InvalidArgument(
    149                     "shape of tensor describing "
    150                     " the output must have dimension 1 and the same "
    151                     " number of elements as the output. Got ",
    152                     split_size_shape.dims(), "-D and ",
    153                     split_size_shape.num_elements(), " elements"));
    154     // get the dimension of this split
    155     xla::Literal split_size_literal;
    156     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal));
    157 
    158     for (int i = 0; i < num_split; ++i) {
    159       int slice_size;
    160       slice_size = split_size_literal.Get<int>({i});
    161       if (slice_size == -1) {
    162         OP_REQUIRES(
    163             ctx, neg_one_dim == -1,
    164             errors::InvalidArgument("Only one dimensions can have a value of"
    165                                     "-1. Second one found at dimension ",
    166                                     i));
    167         neg_one_dim = i;
    168       } else {
    169         split_sizes_vec[i] = slice_size;
    170         total_split_size += slice_size;
    171       }
    172     }
    173 
    174     OP_REQUIRES(
    175         ctx,
    176         (neg_one_dim == -1 &&
    177          total_split_size == input_shape.dim_size(split_dim)) ||
    178             (neg_one_dim >= 0 &&
    179              total_split_size <= input_shape.dim_size(split_dim)),
    180         errors::InvalidArgument("Determined shape must either match "
    181                                 "input shape along split_dim exactly if "
    182                                 "fully specified, or be less than the size of "
    183                                 "the input along split_dim if not fully "
    184                                 "specified.  Got: ",
    185                                 total_split_size));
    186 
    187     if (neg_one_dim >= 0) {
    188       split_sizes_vec[neg_one_dim] =
    189           input_shape.dim_size(split_dim) - total_split_size;
    190     }
    191 
    192     // The vectors we will use to define the slice. The entry for the
    193     // split dimensions varies for each output.
    194     std::vector<int64> begin(input_shape.dims(), 0);
    195     auto dim_sizes = input_shape.dim_sizes();
    196     std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
    197     std::vector<int64> strides(input_shape.dims(), 1);
    198     for (int i = 0; i < num_split; ++i) {
    199       TensorShape output_shape(input_shape);
    200       int slice_size = split_sizes_vec[i];
    201       output_shape.set_dim(split_dim, slice_size);
    202 
    203       // Slice out the ith split from the split dimension.
    204       limits[split_dim] = begin[split_dim] + slice_size;
    205       ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
    206       begin[split_dim] = limits[split_dim];
    207     }
    208   }
    209 };
    210 
    211 REGISTER_XLA_OP(Name("SplitV")
    212                     .CompileTimeConstInput("split_dim")
    213                     .CompileTimeConstInput("size_splits"),
    214                 SplitVOp);
    215 
    216 }  // namespace
    217 }  // namespace tensorflow
    218