/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_
#define TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

// Helper function that copies gathered slices from `params` into `out`,
// using memcpy for simple types and an Eigen loop otherwise. Returns the
// position of the first out-of-range index found, or -1 if all indices are
// valid.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopies(OpKernelContext* ctx,
                        typename TTypes<T, 3>::ConstTensor params,
                        typename TTypes<Index>::ConstFlat indices,
                        SliceIndex slice_elems,
                        typename TTypes<T, 3>::Tensor out) {
  const SliceIndex indices_size = static_cast<SliceIndex>(indices.dimension(0));
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const Index limit = static_cast<Index>(params.dimension(1));
  T* out_base = &out(0, 0, 0);
  const T* params_base = &params(0, 0, 0);
  if (static_slice_elems >= 0) {
    // Give compiler static knowledge of the number of elements/bytes
    slice_elems = static_slice_elems;
  }
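  // When static_slice_elems >= 0, slice_elems is a compile-time constant at
  // this point, so slice_bytes below is too; this lets the compiler
  // specialize the memcpy in the inner loop (e.g. emit fixed-size moves
  // instead of a generic call).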
  // Compute slice_bytes here so that static knowledge is available.
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Stores the position of an invalid index for error reporting; it is
  // shared across worker threads, hence guarded by `mu`.
  SliceIndex result = -1;
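  // The copy work is expressed over the flattened [batch, index] iteration
  // space of size batch_size * indices_size; each shard receives a contiguous
  // [start, end) range of that space and decodes it back into
  // (batch_idx, indices_idx) coordinates below.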
  auto work = [&](int64 start, int64 end) {
    SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(start % indices_size);
    SliceIndex batch_idx_end = static_cast<SliceIndex>(end / indices_size);
    SliceIndex indices_idx_end = static_cast<SliceIndex>(end % indices_size);

    while ((batch_idx < batch_idx_end) ||
           (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex b_next = batch_idx + 1;
      // Prefetch the source and destination of the next iteration's copy to
      // hide memory latency behind the current one.
      if ((batch_idx == batch_idx_end && i_next < indices_idx_end) ||
          (i_next < indices_size)) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(batch_idx, indices(i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0));
        b_next = batch_idx;
      } else if (b_next <= batch_idx_end) {
        // The next iteration wraps to the first index of the next batch.
        port::prefetch<port::PREFETCH_HINT_T0>(&params(b_next, indices(0), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0));
        i_next = 0;
      }
      // Copy the index before validating it, so a concurrent writer cannot
      // change it between the bounds check and its use below.
      const Index index = internal::SubtleMustCopy(indices(indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);
        result = indices_idx;
        return;
      }
      // Copy using memcpy if possible, otherwise an Eigen loop.
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to
      // improve ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            out_base + (batch_idx * indices_size + indices_idx) * slice_elems,
            params_base + (batch_idx * static_cast<SliceIndex>(limit) +
                           static_cast<SliceIndex>(index)) *
                              slice_elems,
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<1>(indices_idx) = params.template chip<1>(index);
      }
      indices_idx = i_next;
      batch_idx = b_next;
    }
  };

  // The cost per work unit handed to the sharder is the number of bytes
  // copied per (batch, index) pair.
  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}
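
// A minimal usage sketch (hypothetical, for illustration only): with a float
// params tensor viewed as [batch, limit, slice] and int32 indices,
// HandleCopies can be driven directly; -1 means no compile-time slice size.
//
//   auto params3 = params_tensor.shaped<float, 3>({1, limit, slice});
//   auto out3 = out_tensor.shaped<float, 3>({1, num_indices, slice});
//   auto idx = indices_tensor.flat<int32>();
//   const int32 bad_i = HandleCopies<float, int32, int32, -1>(
//       ctx, params3, idx, static_cast<int32>(slice), out3);
//   // bad_i == -1 on success; otherwise idx(bad_i) was out of range.
//
// Callers normally go through GatherFunctor below, which also chooses
// between 32-bit and 64-bit internal indexing.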

template <typename T, typename Index>
struct GatherFunctorCPU {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out) {
    const int64 N = indices.size();
    const int64 slice_size = out.dimension(2);
    int64 bad_i;

    // Use 64-bit internal indexing only when some extent exceeds the int32
    // range; 32-bit indexing is cheaper in the inner loop.
    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      N > std::numeric_limits<int32>::max());
#define CALL(elems)                                                      \
  do {                                                                   \
    if (use_large) {                                                     \
      bad_i = HandleCopies<T, Index, int64, elems>(ctx, params, indices, \
                                                   slice_size, out);     \
    } else {                                                             \
      const int32 small_slice = static_cast<int32>(slice_size);          \
      bad_i = HandleCopies<T, Index, int32, elems>(ctx, params, indices, \
                                                   small_slice, out);    \
    }                                                                    \
  } while (0)

    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);
#undef CALL

    return bad_i;
  }
};
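
// The hard-coded 10/20 cases above hand HandleCopies a compile-time slice
// size for what are presumably common slice widths, while CALL(-1) is the
// generic runtime-sized path. Specializing another hot width would be a
// one-line change inside the function, e.g. (hypothetical):
//
//   else if (slice_size == 32)
//     CALL(32);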

template <typename Device, typename T, typename Index>
struct GatherFunctor {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out);
};

template <typename T, typename Index>
struct GatherFunctor<CPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out) {
    return GatherFunctorCPU<T, Index>()(ctx, params, indices, out);
  }
};
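
// Sketch of how an op kernel might drive this functor (hypothetical names,
// mirroring but not quoting the actual gather op): the kernel flattens
// params to [outer, gather_dim, inner] and dispatches on the device type.
//
//   template <typename Device, typename T, typename Index>
//   void ComputeGather(OpKernelContext* ctx, const Tensor& params,
//                      const Tensor& indices, Tensor* out,
//                      int64 outer, int64 gather_dim, int64 inner) {
//     auto params3 = params.shaped<T, 3>({outer, gather_dim, inner});
//     auto indices_flat = indices.flat<Index>();
//     auto out3 = out->shaped<T, 3>({outer, indices_flat.size(), inner});
//     functor::GatherFunctor<Device, T, Index> f;
//     const int64 bad_i = f(ctx, params3, indices_flat, out3);
//     OP_REQUIRES(ctx, bad_i < 0,
//                 errors::InvalidArgument("indices out of range at position ",
//                                         bad_i));
//   }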

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_