/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

#define EIGEN_USE_THREADS
using CPUDevice = Eigen::ThreadPoolDevice;

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//    you need to skip in order to make it over from one side of a dimension
//    to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension that the roll starts to wrap
//    back to the front
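// Illustrative example (not part of the original comments): for an input of
// shape [2, 3, 5] rolled by 2 along the last axis, dim_size = {2, 3, 5},
// dim_range = {30, 15, 5}, and threshold = {0, 0, 3}; the element at index 3
// of the last axis wraps around to output index 0.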
template <typename T>
void DoRoll(OpKernelContext* context, const int64 num_elements,
            const int num_dims, const gtl::ArraySlice<int>& dim_size,
            const T* input, T* output, const gtl::ArraySlice<int>& threshold,
            const gtl::ArraySlice<int64>& dim_range) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
                  int64 start, int64 end) {
    // array of indices for each dimension
    gtl::InlinedVector<int, 4> indices(num_dims);
    int offset = 0;  // the shift along the flattened tensor for current element
    // initialize indices and offset
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      const int shifted_indx = (indx + shift) % dim_size[i];
      offset += (shifted_indx - indx) * stride;
    }

    for (int64 i = start; i < end; i++) {
      output[i + offset] = input[i];
      // create next combination of indices
      // while at it adjust offset if needed
      for (int j = num_dims - 1; j >= 0; j--) {
        const int indx = (indices[j] + 1) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {  // we've reached the threshold
            // dim_range[j] = threshold[j] + shift[j]
            // offset = shift[j] + ... other offsets
            // offset - dim_range[j] = -threshold[j] + ... other offsets
            // thus we undo our previous offset as well as add a new offset of
            // -threshold[j] in one operation
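            // (Illustrative example, assuming j is the innermost dimension:
            // with dim_size[j] = 5, shift[j] = 2, threshold[j] = 3 and
            // stride = 1, this dimension's contribution to offset drops from
            // +2 to -3 when indx reaches 3, i.e. by dim_range[j] = 5.)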
            offset -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          offset += dim_range[j];        // indx became 0 so reverse wrap around
        }
      }
    }
  };
  // Shard
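  // (Shard, declared in work_sharder.h, partitions [0, num_elements) across
  // the CPU worker threads and invokes work(start, end) on each subrange.)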
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  // 15 - experimentally determined with float and bool types
  const int cost_per_element = 15 * sizeof(T);  // rough estimate
  Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
        cost_per_element, std::move(work));
}

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//    you need to skip in order to make it over from one side of a dimension
//    to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension that the roll starts to wrap
//    back to the front
// isd - inner shift dimension
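// Illustrative example (not part of the original comments): rolling a [4, 6]
// tensor by 2 along axis 1 gives isd = 1 and threshold[1] = 4, so every row
// splits into two contiguous groups (4 elements, then 2 elements) that also
// land contiguously in the output and can each be copied with one memcpy.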
// Use memcpy to copy memory in groups when the data type supports memcpy
template <typename T>
void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
                      const int num_dims, const gtl::ArraySlice<int>& dim_size,
                      const T* input, T* output,
                      const gtl::ArraySlice<int>& threshold,
                      const gtl::ArraySlice<int64>& dim_range,
                      const int64 isd) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
                  int64 start, int64 end) {
    // the number of indices over in the flattened tensor you need to skip in
    // order to make it over from one side of the isd to the other
    const int64 isd_range = std::max<int>(dim_range[isd], 1);
    // the distance along the flattened tensor to the next element in the isd
    const int64 isd_stride = isd_range / std::max<int>(dim_size[isd], 1);

    // start and end are group indices at this point, so convert them into
    // element indices in the flattened tensor. Each isd slice contains two
    // groups: one for all elements before threshold[isd] and another for all
    // elements at or after threshold[isd].
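    // (Illustrative, continuing the [4, 6] roll-by-2-along-axis-1 example:
    // isd_range = 6, isd_stride = 1, threshold[isd] = 4, so group index 3
    // maps to element index (3 / 2) * 6 + (3 % 2) * 4 * 1 = 10, the start of
    // the second group of row 1.)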
    const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
    const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
    start = (start / 2) * isd_range + start_remainder;
    end = (end / 2) * isd_range + end_remainder;

    const T* in_ptr = &input[0];
    T* out_ptr = &output[0];
    in_ptr += start;
    out_ptr += start;

    // array of indices for each dimension
    // indices = [i, j, k, l, m, n]
    gtl::InlinedVector<int, 4> indices(num_dims);
    // the offset needed to make all inner non-shifting dimensions become 0
    int64 remainder_offset = 0;
    // initialize indices
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      int out_indx = (indx + shift) % dim_size[i];
      if (i > isd) {
        // trailing zeroes for indices after the inner shifted dimension
        out_indx = 0;
        remainder_offset += (out_indx - indx) * stride;
      }
      out_ptr += (out_indx - indx) * stride;
    }
    // set trailing zeroes for indices after the inner shifted dimension
    for (int i = num_dims - 1; i > isd; i--) indices[i] = 0;

    // the number of indices in the isd dimension the next group will skip
    // to make it to the next threshold or end point
    int isd_indx_skip = 0;
    // the size of the next group
    int64 group_size = 0;
    // initialize isd_indx_skip and group_size
    if (indices[isd] < threshold[isd]) {
      isd_indx_skip = threshold[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    } else {
      isd_indx_skip = dim_size[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    }

    int64 i = start;
    while (i < end) {
      // copy group of elements
      memcpy(out_ptr, in_ptr, group_size * sizeof(T));

      // shift i and the pointers over to the next group position
      i += group_size;
      out_ptr += group_size;
      in_ptr += group_size;

      // produce next combination of indices and adjust the out_ptr position
      // to fix the offset if necessary
      // the isd (inner shift dim) should skip to next threshold or endpoint
      // all dimensions to the left increment by 1 when a digit is carried
      // all dimensions to the right remain set to 0
      //           +1 +1 +1 +isd_indx_skip
      // indices = [i, j, k, l, 0, 0]
      //                     ^isd
      for (int j = isd; j >= 0; j--) {
        int inc = 1;
        if (j == isd) inc = isd_indx_skip;
        const int indx = (indices[j] + inc) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {
            out_ptr -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          out_ptr += dim_range[j];       // indx became 0 so reverse wrap around
        }
      }

      // set isd_indx_skip and group_size for next iteration
      if (indices[isd] < threshold[isd]) {
        isd_indx_skip = threshold[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      } else {
        isd_indx_skip = dim_size[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  const int64 ave_group_size = dim_range[isd] / 2;
  const int total_work = 2 * num_elements / std::max<int>(dim_range[isd], 1);
  // 25000 - experimentally determined with float and bool types
  const int cost_per_group = 25000 * sizeof(T) * ave_group_size;
  Shard(worker_threads->num_threads, worker_threads->workers, total_work,
        cost_per_group, std::move(work));
}

template <typename Device, typename T, typename Tshift, typename Taxis>
class RollOp : public OpKernel {
 public:
  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input = context->input(0);
    const Tensor& shift = context->input(1);
    const Tensor& axis = context->input(2);

    auto shift_flat = shift.flat<Tshift>();
    auto axis_flat = axis.flat<Taxis>();

    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(context, shift.shape().dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift.shape().DebugString()));
    OP_REQUIRES(context, axis.shape().dims() <= 1,
                errors::InvalidArgument(
                    "axis must be a scalar or a 1-D vector. Found: ",
                    axis.shape().DebugString()));
    OP_REQUIRES(
        context, shift.shape() == axis.shape(),
        errors::InvalidArgument("shift and axis must have the same size"));
    const int64 num_elements = input.NumElements();
    const int num_shifts = static_cast<int>(shift_flat.size());
    const int num_dims = input.dims();

    // if there are any duplicate axes, shift_mod_sum will have the
    // total modulo sum of shifts for each dimension
    gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
    for (int i = 0; i < num_shifts; i++) {
      const int axis = axis_flat(i);
      OP_REQUIRES(context, axis < num_dims,
                  errors::InvalidArgument("axis ", axis, " is out of range"));
      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
      // modulo that works with negatives: ((x % y) + y) % y
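      // e.g. for x = -3, y = 5, plain x % y yields -3 in C++, while
      // ((x % y) + y) % y yields the desired 2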
      shift_mod_sum[axis] = (sum % ds + ds) % ds;
    }
    // the size of each dimension
    gtl::InlinedVector<int, 4> dim_size(num_dims);
    // threshold[i] is the index that the roll starts to wrap back to the front
    gtl::InlinedVector<int, 4> threshold(num_dims);
    // dim_range is the number of indices over in the flattened tensor
    // you need to skip in order to make it over from one side of a dimension
    // to the other. Used to make the shifts wrap around after a threshold.
    gtl::InlinedVector<int64, 4> dim_range(num_dims);
    int64 dim_size_prod = 1;  // dimension size product
    // inner shift dimension (innermost shifted dimension)
    int64 isd = 0;
    for (int i = num_dims - 1; i >= 0; i--) {
      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
      dim_size[i] = ds;
      threshold[i] = (ds - shift_mod_sum[i]) % ds;
      dim_size_prod *= static_cast<int64>(input.dim_size(i));
      dim_range[i] = dim_size_prod;
    }

    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto input_flat = input.flat<T>().data();
    auto output_flat = output->flat<T>().data();

    if (std::is_same<Device, CPUDevice>::value) {
      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        // V2 copies memory in groups instead of element by element
        DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
                            input_flat, output_flat, threshold, dim_range, isd);
      } else {
        // in case memcpy does not work for the current data type
        DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
                  output_flat, threshold, dim_range);
      }
    }
  }
};
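// Example of the op this kernel implements, as exposed from Python (the API
// name may vary across TensorFlow versions, e.g. tf.manip.roll / tf.roll):
//   roll([1, 2, 3, 4, 5], shift=2, axis=0)  ==> [4, 5, 1, 2, 3]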

// Register the CPU kernels.
#define REGISTER_CPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int64>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int64>)

TF_CALL_ALL_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
}  // namespace tensorflow