/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

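    // flat_inner_dims<T>() collapses all leading dimensions into rows and
    // keeps the last dimension as columns, so top-k is computed independently
    // for each row.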
    const auto& input = input_in.flat_inner_dims<T>();

    const int64 num_rows = input.dimension(0);  // generally batch_size
    const int64 num_cols = input.dimension(1);

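    // The outputs keep the input's shape, except that the last dimension is
    // replaced by k.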
    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing.
    if (k == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
    if (k == 1) {
#ifdef EIGEN_HAS_INDEX_LIST
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);
#else
      Eigen::array<int, 1> reduce_on_cols = {1};
      Eigen::array<int, 2> rows_by_one = {static_cast<int>(num_rows), 1};
#endif

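      // A single reduction along the column dimension yields each row's
      // maximum; the matching column index is recovered by the scan below.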
      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
      }

      return Status::OK();
    }

    auto SortIndices = [&, context](int start_batch, int limit_batch) {
      for (int32 b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
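        // Orders indices by decreasing value; ties are broken by increasing
        // index, so equal values come out in a deterministic, stable order.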
        const auto stable_comp = [input_data](const int32 a, const int32 b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32 a, const int32 b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
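          // The TopN filter retains only the k best indices (under
          // stable_comp) out of the num_cols candidates pushed into it.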
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32 c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32 i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(&indices(b, 0), &indices(b, k), &values(b, 0),
                       [b, &input](const int32 loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
                                 ? kint64max
                                 : static_cast<int64>(total_cost);
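    // Shard the per-row work across the CPU worker threads; final_cost is the
    // estimated cost of handling a single row.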
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return Status::OK();
  }
};

}  // namespace functor

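// The same CPU kernel implements both TopK (k given as an attr) and TopKV2 (k
// given as an input tensor).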
#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#ifdef GOOGLE_CUDA

namespace functor {
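// Forward declarations of the GPU specializations; their definitions live in
// the CUDA compilation unit.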
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
      OpKernelContext* context, bool sorted, int k,                          \
      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                              \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

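// For TopKV2 on GPU, k is read on the host when computing the output shape,
// so the k input is pinned to host memory.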
#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);

#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA

}  // end namespace tensorflow