/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// EIGEN_USE_THREADS must be defined before the framework headers pull in
// Eigen; otherwise the ThreadPoolDevice specializations are not visible.
#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//     you need to skip in order to make it over from one side of a dimension
//     to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension at which the roll starts to wrap
//     back to the front
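// Concretely, rolling a [2, 3] tensor by 1 along axis 1 gives
//   dim_size  = [2, 3]
//   dim_range = [6, 3]
//   threshold = [0, 2]  (input index 2 along axis 1 wraps to output index 0)
// so each row [a, b, c] becomes [c, a, b].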
template <typename T>
void DoRoll(OpKernelContext* context, const int64 num_elements,
            const int num_dims, const gtl::ArraySlice<int>& dim_size,
            const T* input, T* output, const gtl::ArraySlice<int>& threshold,
            const gtl::ArraySlice<int64>& dim_range) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
                  int64 start, int64 end) {
    // array of indices for each dimension
    gtl::InlinedVector<int, 4> indices(num_dims);
    // the shift along the flattened tensor for the current element
    int offset = 0;
    // initialize indices and offset
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      const int shifted_indx = (indx + shift) % dim_size[i];
      offset += (shifted_indx - indx) * stride;
    }

    for (int64 i = start; i < end; i++) {
      output[i + offset] = input[i];
      // create the next combination of indices
      // while at it, adjust offset if needed
      for (int j = num_dims - 1; j >= 0; j--) {
        const int indx = (indices[j] + 1) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {  // we've reached the threshold
            // this dimension's offset contribution was shift[j] * stride;
            // past the threshold it must become -threshold[j] * stride.
            // Since dim_range[j] == (threshold[j] + shift[j]) * stride,
            // subtracting dim_range[j] undoes the previous contribution and
            // applies the new one in a single operation.
            offset -= dim_range[j];  // now wraps around
          }
          break;  // indx != 0, no need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0, shift is 0
          offset += dim_range[j];  // indx became 0, so reverse the wrap around
        }
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  // 15 - experimentally determined with float and bool types
  const int cost_per_element = 15 * sizeof(T);  // rough estimate
  Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
        cost_per_element, std::move(work));
}

// Uses memcpy to copy contiguous groups of elements when the data type
// supports memcpy.
// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//     you need to skip in order to make it over from one side of a dimension
//     to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension at which the roll starts to wrap
//     back to the front
// isd - inner shift dimension (the innermost dimension with a nonzero shift)
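// Concretely, for the [2, 3] tensor rolled by 1 along axis 1 above
// (isd == 1, isd_stride == 1, threshold[1] == 2), each row splits into two
// groups: the two elements before the threshold and the one element after
// it. Row [a, b, c] is copied as group [a, b] to output offset +1 and group
// [c] to output offset -2, yielding [c, a, b].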
template <typename T>
void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
                      const int num_dims, const gtl::ArraySlice<int>& dim_size,
                      const T* input, T* output,
                      const gtl::ArraySlice<int>& threshold,
                      const gtl::ArraySlice<int64>& dim_range,
                      const int64 isd) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range,
               isd](int64 start, int64 end) {
    // the number of indices over in the flattened tensor you need to skip in
    // order to make it over from one side of the isd to the other
    const int64 isd_range = std::max<int64>(dim_range[isd], 1);
    // the distance along the flattened tensor to the next element in the isd
    const int64 isd_stride = isd_range / std::max<int64>(dim_size[isd], 1);

    // start and end are group indices here, so convert them into element
    // indices into the flattened tensor. There are two groups per isd slice:
    // one for all elements before threshold[isd] and another for all
    // elements after it.
    const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
    const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
    start = (start / 2) * isd_range + start_remainder;
    end = (end / 2) * isd_range + end_remainder;

    const T* in_ptr = &input[0];
    T* out_ptr = &output[0];
    in_ptr += start;
    out_ptr += start;

    // array of indices for each dimension
    // indices = [i, j, k, l, m, n]
    gtl::InlinedVector<int, 4> indices(num_dims);
    // the offset needed to make all inner non-shifting dimensions become 0
    int64 remainder_offset = 0;
    // initialize indices
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      int out_indx = (indx + shift) % dim_size[i];
      if (i > isd) {
        // trailing zeroes for indices after the inner shifted dimension
        out_indx = 0;
        remainder_offset += (out_indx - indx) * stride;
      }
      out_ptr += (out_indx - indx) * stride;
    }
    // set trailing zeroes for indices after the inner shifted dimension
    for (int i = num_dims - 1; i > isd; i--) indices[i] = 0;

    // the number of indices in the isd dimension the next group will skip
    // to make it to the next threshold or end point
    int isd_indx_skip = 0;
    // the size of the next group
    int64 group_size = 0;
    // initialize isd_indx_skip and group_size
    if (indices[isd] < threshold[isd]) {
      isd_indx_skip = threshold[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    } else {
      isd_indx_skip = dim_size[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    }

    int64 i = start;
    while (i < end) {
      // copy a group of elements
      memcpy(out_ptr, in_ptr, group_size * sizeof(T));

      // shift i and the pointers over to the next group position
      i += group_size;
      out_ptr += group_size;
      in_ptr += group_size;

      // produce the next combination of indices and adjust the out_ptr
      // position to fix the offset if necessary.
      // the isd (inner shift dim) skips to the next threshold or endpoint;
      // all dimensions to its left increment by 1 when a digit is carried,
      // and all dimensions to its right remain set to 0.
      //            +1 +1 +1 +isd_indx_skip
      // indices = [i, j, k, l, 0, 0]
      //                       ^isd
      for (int j = isd; j >= 0; j--) {
        int inc = 1;
        if (j == isd) inc = isd_indx_skip;
        const int indx = (indices[j] + inc) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {
            out_ptr -= dim_range[j];  // now wraps around
          }
          break;  // indx != 0, no need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0, shift is 0
          out_ptr += dim_range[j];  // indx became 0, so reverse the wrap around
        }
      }

      // set isd_indx_skip and group_size for the next iteration
      if (indices[isd] < threshold[isd]) {
        isd_indx_skip = threshold[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      } else {
        isd_indx_skip = dim_size[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  const int64 ave_group_size = dim_range[isd] / 2;
  const int64 total_work =
      2 * num_elements / std::max<int64>(dim_range[isd], 1);
  // 25000 - experimentally determined with float and bool types
  const int64 cost_per_group = 25000 * sizeof(T) * ave_group_size;
  Shard(worker_threads->num_threads, worker_threads->workers, total_work,
        cost_per_group, std::move(work));
}

template <typename Device, typename T, typename Tshift, typename Taxis>
class RollOp : public OpKernel {
 public:
  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensors
    const Tensor& input = context->input(0);
    const Tensor& shift = context->input(1);
    const Tensor& axis = context->input(2);

    auto shift_flat = shift.flat<Tshift>();
    auto axis_flat = axis.flat<Taxis>();

    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(context, shift.shape().dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift.shape().DebugString()));
    OP_REQUIRES(context, axis.shape().dims() <= 1,
                errors::InvalidArgument(
                    "axis must be a scalar or a 1-D vector. Found: ",
                    axis.shape().DebugString()));
    OP_REQUIRES(
        context, shift.shape() == axis.shape(),
        errors::InvalidArgument("shift and axis must have the same size"));
    const int64 num_elements = input.NumElements();
    const int num_shifts = static_cast<int>(shift_flat.size());
    const int num_dims = input.dims();

    // if there are any duplicate axes, shift_mod_sum will have the
    // total modulo sum of shifts for each dimension
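    // For example, with a dimension of size 5 and shifts {3, -2} applied to
    // the same axis, the accumulated shift is ((3 - 2) % 5 + 5) % 5 == 1,
    // while a single shift of -2 alone normalizes to ((-2 % 5) + 5) % 5 == 3.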
    gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
    for (int i = 0; i < num_shifts; i++) {
      int axis = axis_flat(i);
      if (axis < 0) axis += num_dims;  // canonicalize negative axis values
      OP_REQUIRES(context, 0 <= axis && axis < num_dims,
                  errors::InvalidArgument("axis ", axis_flat(i),
                                          " is out of range"));
      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
      // modulo that works with negatives: ((x % y) + y) % y
      shift_mod_sum[axis] = (sum % ds + ds) % ds;
    }
    // the size of each dimension
    gtl::InlinedVector<int, 4> dim_size(num_dims);
    // threshold[i] is the index at which the roll starts to wrap back to the
    // front
    gtl::InlinedVector<int, 4> threshold(num_dims);
    // dim_range is the number of indices over in the flattened tensor
    // you need to skip in order to make it over from one side of a dimension
    // to the other. Used to make the shifts wrap around after a threshold.
    gtl::InlinedVector<int64, 4> dim_range(num_dims);
    int64 dim_size_prod = 1;  // dimension size product
    // inner shift dimension (innermost shifted dimension)
    int64 isd = 0;
    for (int i = num_dims - 1; i >= 0; i--) {
      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
      dim_size[i] = ds;
      threshold[i] = (ds - shift_mod_sum[i]) % ds;
      dim_size_prod *= static_cast<int64>(input.dim_size(i));
      dim_range[i] = dim_size_prod;
    }
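    // For example, a [4, 2, 3] tensor shifted by 1 along axis 1 only gives
    //   shift_mod_sum = [0, 1, 0]
    //   dim_size      = [4, 2, 3]
    //   threshold     = [0, 1, 0]
    //   dim_range     = [24, 6, 3]
    //   isd           = 1  (axis 2 is unshifted, so contiguous rows of 3
    //                       elements can be copied with a single memcpy)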
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto input_flat = input.flat<T>().data();
    auto output_flat = output->flat<T>().data();

    if (std::is_same<Device, CPUDevice>::value) {
      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        // DoRollWithMemcpy copies memory in groups instead of element by
        // element
        DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
                            input_flat, output_flat, threshold, dim_range,
                            isd);
      } else {
        // in case memcpy does not work for the current data type
        DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
                  output_flat, threshold, dim_range);
      }
    }
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int64>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int64>)

TF_CALL_ALL_TYPES(REGISTER_CPU);
#undef REGISTER_CPU

}  // namespace tensorflow