/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

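// DynamicPartition splits `data` into `num_partitions` output tensors
// according to the values in `partitions`. For example, with
// num_partitions = 2, partitions = [0, 0, 1, 1, 0], and
// data = [10, 20, 30, 40, 50], the outputs are
//   outputs[0] = [10, 20, 50] and outputs[1] = [30, 40].
// If data has more dimensions than partitions, whole slices of data's
// trailing dimensions are routed to the outputs instead of single elements.
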
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// Shared code that is not dependent on the type of T.  We do this to reduce
// code size by not duplicating all this for all T (float, double, int32, etc.)
class DynamicPartitionOp_Shared : public OpKernel {
 public:
  explicit DynamicPartitionOp_Shared(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_));
    //   QUESTION: It'd be nice to support DT_INT16, DT_UINT8, etc.
    //   to input[1].  Should we have the framework do some sort of
    //   integer promotion automatically, or should that be something
    //   that users have to do explicitly with a conversion operator
    //   in the graph?
  }

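  // Validates the "data" and "partitions" inputs, counts how many entries
  // fall into each partition, and allocates one correctly sized output tensor
  // per partition.  On error, sets the context status and returns early.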
  void ValidateAndAllocateOutputs(OpKernelContext* c, const Tensor** data,
                                  const Tensor** partitions,
                                  OpOutputList* Tout) {
    OP_REQUIRES_OK(c, c->input("data", data));
    OP_REQUIRES_OK(c, c->input("partitions", partitions));
    OP_REQUIRES(
        c,
        TensorShapeUtils::StartsWith((*data)->shape(), (*partitions)->shape()),
        errors::InvalidArgument(
            "data.shape must start with partitions.shape, ",
            "got data.shape = ", (*data)->shape().DebugString(),
            ", partitions.shape = ", (*partitions)->shape().DebugString()));

    // Count how many occurrences of each partition id we have in partitions
    gtl::InlinedVector<int, 32> partition_count(num_partitions_);
    auto e_partitions = (*partitions)->flat<int32>();
    const int64 N = e_partitions.dimension(0);
    for (int64 i = 0; i < N; i++) {
      const int32 p = internal::SubtleMustCopy(e_partitions(i));
      OP_REQUIRES(c, FastBoundsCheck(p, num_partitions_),
                  errors::InvalidArgument(
                      "partitions", SliceDebugString((*partitions)->shape(), i),
                      " = ", p, " is not in [0, ", num_partitions_, ")"));
      partition_count[p]++;
    }

    // Allocate output tensors of the right size
    OP_REQUIRES_OK(c, c->output_list("outputs", Tout));
    for (int p = 0; p < num_partitions_; p++) {
      TensorShape shape;
      shape.AddDim(partition_count[p]);
      for (int i = (*partitions)->dims(); i < (*data)->dims(); i++) {
        shape.AddDim((*data)->dim_size(i));
      }
      Tensor* out;
      OP_REQUIRES_OK(c, Tout->allocate(p, shape, &out));
    }
  }

 protected:
  int num_partitions_;
};

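// Type-specific kernel: copies elements (or slices, when data has extra
// trailing dimensions) of type T into the outputs allocated by the shared
// base class above.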
template <class T>
class DynamicPartitionOp : public DynamicPartitionOp_Shared {
 public:
  explicit DynamicPartitionOp(OpKernelConstruction* c)
      : DynamicPartitionOp_Shared(c) {}
  void Compute(OpKernelContext* c) override {
    const Tensor* data;
    const Tensor* partitions;
    OpOutputList outputs;
    ValidateAndAllocateOutputs(c, &data, &partitions, &outputs);
    if (!c->status().ok()) return;
    if (num_partitions_ == 0 || data->NumElements() == 0) return;

    auto e_partitions = partitions->flat<int32>();
    const int64 N = e_partitions.dimension(0);
    gtl::InlinedVector<int, 32> output_index(num_partitions_);

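    // Two copy strategies: when partitions and data have the same rank, each
    // data element maps to exactly one output element, so scalars are copied.
    // Otherwise, each partition id selects a whole slice of data's trailing
    // dimensions, which is copied with Eigen slices in the else branch.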
    if (partitions->dims() == data->dims()) {
      // Walk through data and copy the data to the appropriate output tensor
      const auto data_flat = data->flat<T>();
      std::vector<Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                                   Eigen::Aligned> >
          out_vec;
      out_vec.reserve(num_partitions_);
      for (int p = 0; p < num_partitions_; p++) {
        out_vec.push_back(outputs[p]->vec<T>());
      }
      for (int64 i = 0; i < N; i++) {
        const int32 p = internal::SubtleMustCopy(e_partitions(i));
        OP_REQUIRES(
            c, FastBoundsCheck(p, num_partitions_),
            errors::InvalidArgument("partitions[", i, "] is out of range"));
        auto oi = output_index[p];
        OP_REQUIRES(c, FastBoundsCheck(oi, out_vec[p].size()),
                    errors::InvalidArgument(
                        "output_index[", p, "] = ", oi, " exceeds the size ",
                        out_vec[p].size(), " of output ", p));
        out_vec[p](oi) = data_flat(i);
        output_index[p]++;
      }
    } else {
      // If data has extra dimensions, use Eigen slices
      std::vector<Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                                   Eigen::Aligned> >
          out_flat;
      out_flat.reserve(num_partitions_);
      for (int p = 0; p < num_partitions_; p++) {
        out_flat.push_back(outputs[p]->flat_outer_dims<T>());
      }

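      // Each of the N rows of data_flat is one slice of data's trailing
      // dimensions; the division below is exact because data.shape starts
      // with partitions.shape (validated in ValidateAndAllocateOutputs).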
      // Walk through data and copy the data to the appropriate output tensor
      const int64 slice_size = data->NumElements() / N;
      const auto data_flat = data->shaped<T, 2>({N, slice_size});
      Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
      for (int64 i = 0; i < N; i++) {
        // outputs[p][output_index[p]++] = data[i]
        const int32 p = internal::SubtleMustCopy(e_partitions(i));
        OP_REQUIRES(
            c, FastBoundsCheck(p, num_partitions_),
            errors::InvalidArgument("partitions[", i,
                                    "] has been asynchronously overwritten "
                                    "and is no longer in range!"));
        auto oi = output_index[p];
        OP_REQUIRES(c, FastBoundsCheck(oi, out_flat[p].dimension(0)),
                    errors::InvalidArgument("output_index[", p, "] = ", oi,
                                            " is out of range for output ", p));
        Eigen::DSizes<Eigen::DenseIndex, 2> out_indices(oi, 0);
        Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
        out_flat[p].slice(out_indices, sizes) =
            data_flat.slice(data_indices, sizes);
        output_index[p]++;
      }
    }
  }
};

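// Register the CPU kernel for every data type covered by TF_CALL_ALL_TYPES.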
#define REGISTER_DYNAMIC_PARTITION(T)                                     \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DynamicPartition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DynamicPartitionOp<T>)

TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
#undef REGISTER_DYNAMIC_PARTITION

}  // namespace tensorflow