// tensorflow/core/kernels/nth_element_op.cc — CPU kernel implementation.
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/nn_ops.cc.
     17 #include "tensorflow/core/kernels/nth_element_op.h"
     18 
     19 #include <algorithm>
     20 #include <iostream>
     21 #include <vector>
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/util/work_sharder.h"
     28 
     29 namespace tensorflow {
     30 
     31 typedef Eigen::ThreadPoolDevice CPUDevice;
     32 
     33 template <typename Device, typename T>
     34 class NthElementOp : public OpKernel {
     35  public:
     36   explicit NthElementOp(OpKernelConstruction* context) : OpKernel(context) {
     37     OP_REQUIRES_OK(context, context->GetAttr("reverse", &reverse_));
     38   }
     39 
     40   void Compute(OpKernelContext* context) override {
     41     // The second args is N, which must be a positive scalar.
     42     const auto& n_in = context->input(1);
     43     OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()),
     44                 errors::InvalidArgument("N must be scalar, got shape ",
     45                                         n_in.shape().DebugString()));
     46     int n = n_in.scalar<int32>()();
     47     OP_REQUIRES(context, n >= 0,
     48                 errors::InvalidArgument("Need n >= 0, got ", n));
     49 
     50     // The first args is input tensor, which must have 1 dimension at least.
     51     const Tensor& input_in = context->input(0);
     52     const int num_dims = input_in.dims();
     53     OP_REQUIRES(context, num_dims >= 1,
     54                 errors::InvalidArgument("Input must be >= 1-D, got shape ",
     55                                         input_in.shape().DebugString()));
     56     // The last dimension of input tensor must be greater than N.
     57     OP_REQUIRES(
     58         context, input_in.dim_size(num_dims - 1) > n,
     59         errors::InvalidArgument("Input must have at least n+1 columns"));
     60 
     61     // std::nth_element only support the nth-smallest selection.
     62     if (reverse_) {
     63       n = input_in.dim_size(num_dims - 1) - n - 1;
     64     }
     65 
     66     // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1].
     67     TensorShape out_shape;
     68     for (int i = 0; i < num_dims - 1; ++i) {
     69       out_shape.AddDim(input_in.dim_size(i));
     70     }
     71     Tensor* output_tensor = nullptr;
     72     OP_REQUIRES_OK(context,
     73                    context->allocate_output(0, out_shape, &output_tensor));
     74 
     75     functor::NthElementFunctor<Device, T> nthElementFunc;
     76     nthElementFunc(context, input_in, *output_tensor, n, reverse_);
     77   }
     78 
     79  private:
     80   bool reverse_;
     81 };
     82 
     83 namespace functor {
     84 
     85 template <typename T>
     86 struct NthElementFunctor<CPUDevice, T> {
     87   void operator()(OpKernelContext* context, const Tensor& input_tensor,
     88                   Tensor& output_tensor, int n, bool reverse) {
     89     const T* input = input_tensor.flat<T>().data();
     90     T* output = output_tensor.flat<T>().data();
     91 
     92     // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1],
     93     // then num_rows = d1*d2...dk-1, last_dim = dk.
     94     const int num_rows = output_tensor.NumElements();
     95     const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1);
     96 
     97     // Allocate each row to different shard.
     98     auto SubNthElement = [&, input, output, last_dim, n](int start, int limit) {
     99       // std::nth_element would rearrange the array, so we need a new buffer.
    100       std::vector<T> buf(last_dim);
    101 
    102       for (int b = start; b < limit; ++b) {
    103         // Copy from one row of elements to buffer
    104         const T* input_start = input + b * last_dim;
    105         const T* input_end = input + (b + 1) * last_dim;
    106         std::copy(input_start, input_end, buf.begin());
    107 
    108         std::nth_element(buf.begin(), buf.begin() + n, buf.end());
    109         // The element placed in the nth position is exactly the element that
    110         // would occur in this position if the range was fully sorted.
    111         output[b] = buf[n];
    112       }
    113     };
    114 
    115     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    116     // The average time complexity of partition-based nth_element (BFPRT) is
    117     // O(n), althought the worst time complexity could be O(n^2). Here, 20 is a
    118     // empirical factor of cost_per_unit.
    119     Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
    120           20 * last_dim, SubNthElement);
    121   }
    122 };
    123 
    124 }  // namespace functor
    125 
// Register the CPU kernel for every real numeric type T supported by
// TF_CALL_REAL_NUMBER_TYPES; the macro is undefined immediately after use.
#define REGISTER_NTHOP(T)                                           \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("NthElement").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      NthElementOp<CPUDevice, T>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_NTHOP);
#undef REGISTER_NTHOP
    133 
    134 }  // end namespace tensorflow
    135