1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/data_flow_ops.cc. 17 18 #include <vector> 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/framework/register_types.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/types.h" 23 #include "tensorflow/core/kernels/bounds_check.h" 24 #include "tensorflow/core/lib/gtl/inlined_vector.h" 25 #include "tensorflow/core/util/util.h" 26 27 namespace tensorflow { 28 29 // Shared code that is not dependent on the type of T. We do this to reduce 30 // code size by not duplicating all this for all T (float, double, int32, etc.) 31 class DynamicPartitionOp_Shared : public OpKernel { 32 public: 33 explicit DynamicPartitionOp_Shared(OpKernelConstruction* c) : OpKernel(c) { 34 OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_)); 35 // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, etc. 36 // to input[1]. Should we have the framework do some sort of 37 // integer promotion automatically, or should that be something 38 // that users have to do explicitly with a conversion operator 39 // in the graph? 40 } 41 42 void ValidateAndAllocateOutputs(OpKernelContext* c, const Tensor** data, 43 const Tensor** partitions, 44 OpOutputList* Tout) { 45 OP_REQUIRES_OK(c, c->input("data", data)); 46 OP_REQUIRES_OK(c, c->input("partitions", partitions)); 47 OP_REQUIRES( 48 c, 49 TensorShapeUtils::StartsWith((*data)->shape(), (*partitions)->shape()), 50 errors::InvalidArgument( 51 "data.shape must start with partitions.shape, ", 52 "got data.shape = ", (*data)->shape().DebugString(), 53 ", partitions.shape = ", (*partitions)->shape().DebugString())); 54 55 // Count how many occurrences of each partition id we have in partitions 56 gtl::InlinedVector<int, 32> partition_count(num_partitions_); 57 auto e_partitions = (*partitions)->flat<int32>(); 58 const int64 N = e_partitions.dimension(0); 59 for (int64 i = 0; i < N; i++) { 60 const int32 p = internal::SubtleMustCopy(e_partitions(i)); 61 OP_REQUIRES(c, FastBoundsCheck(p, num_partitions_), 62 errors::InvalidArgument( 63 "partitions", SliceDebugString((*partitions)->shape(), i), 64 " = ", p, " is not in [0, ", num_partitions_, ")")); 65 partition_count[p]++; 66 } 67 68 // Allocate output tensors of the right size 69 OP_REQUIRES_OK(c, c->output_list("outputs", Tout)); 70 for (int p = 0; p < num_partitions_; p++) { 71 TensorShape shape; 72 shape.AddDim(partition_count[p]); 73 for (int i = (*partitions)->dims(); i < (*data)->dims(); i++) { 74 shape.AddDim((*data)->dim_size(i)); 75 } 76 Tensor* out; 77 OP_REQUIRES_OK(c, Tout->allocate(p, shape, &out)); 78 } 79 } 80 81 protected: 82 int num_partitions_; 83 }; 84 85 template <class T> 86 class DynamicPartitionOp : public DynamicPartitionOp_Shared { 87 public: 88 explicit DynamicPartitionOp(OpKernelConstruction* c) 89 : DynamicPartitionOp_Shared(c) {} 90 void Compute(OpKernelContext* c) override { 91 const Tensor* data; 92 const Tensor* partitions; 93 OpOutputList outputs; 94 ValidateAndAllocateOutputs(c, &data, &partitions, &outputs); 95 if (!c->status().ok()) return; 96 if (num_partitions_ == 0 || data->NumElements() == 0) return; 97 98 auto e_partitions = partitions->flat<int32>(); 99 const int64 N = e_partitions.dimension(0); 100 gtl::InlinedVector<int, 32> output_index(num_partitions_); 101 102 if (partitions->dims() == data->dims()) { 103 // Walk through data and copy the data to the appropriate output tensor 104 const auto data_flat = data->flat<T>(); 105 std::vector<Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, 106 Eigen::Aligned> > 107 out_vec; 108 out_vec.reserve(num_partitions_); 109 for (int p = 0; p < num_partitions_; p++) { 110 out_vec.push_back(outputs[p]->vec<T>()); 111 } 112 for (int64 i = 0; i < N; i++) { 113 const int32 p = internal::SubtleMustCopy(e_partitions(i)); 114 OP_REQUIRES( 115 c, FastBoundsCheck(p, num_partitions_), 116 errors::InvalidArgument("indices[", i, "] is out of range")); 117 auto oi = output_index[p]; 118 OP_REQUIRES(c, FastBoundsCheck(oi, out_vec[p].size()), 119 errors::InvalidArgument( 120 "out_vec[", p, "] size: ", out_vec[p].size(), 121 " is not LTE output_index[", p, "] : ", oi)); 122 out_vec[p](oi) = data_flat(i); 123 output_index[p]++; 124 } 125 } else { 126 // If data has extra dimensions, use Eigen slices 127 std::vector<Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, 128 Eigen::Aligned> > 129 out_flat; 130 out_flat.reserve(num_partitions_); 131 for (int p = 0; p < num_partitions_; p++) { 132 out_flat.push_back(outputs[p]->flat_outer_dims<T>()); 133 } 134 135 // Walk through data and copy the data to the appropriate output tensor 136 const int64 slice_size = data->NumElements() / N; 137 const auto data_flat = data->shaped<T, 2>({N, slice_size}); 138 Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size); 139 for (int64 i = 0; i < N; i++) { 140 // outputs[p][output_index[p]++] = data[i] 141 const int32 p = internal::SubtleMustCopy(e_partitions(i)); 142 OP_REQUIRES( 143 c, FastBoundsCheck(p, num_partitions_), 144 errors::InvalidArgument("indices[", i, 145 "] has been asynchronously overwitten and " 146 "is no longer in range!")); 147 auto oi = output_index[p]; 148 OP_REQUIRES(c, FastBoundsCheck(oi, out_flat[p].dimension(0)), 149 errors::InvalidArgument("Size of output_index: ", oi, 150 " is no longer in range.")); 151 Eigen::DSizes<Eigen::DenseIndex, 2> out_indices(oi, 0); 152 Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0); 153 out_flat[p].slice(out_indices, sizes) = 154 data_flat.slice(data_indices, sizes); 155 output_index[p]++; 156 } 157 } 158 } 159 }; 160 161 #define REGISTER_DYNAMIC_PARTITION(T) \ 162 REGISTER_KERNEL_BUILDER( \ 163 Name("DynamicPartition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 164 DynamicPartitionOp<T>) 165 166 TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION); 167 #undef REGISTER_DYNAMIC_PARTITION 168 169 } // namespace tensorflow 170