1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 #include "tensorflow/core/kernels/nth_element_op.h" 18 19 #include <algorithm> 20 #include <iostream> 21 #include <vector> 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/util/work_sharder.h" 28 29 namespace tensorflow { 30 31 typedef Eigen::ThreadPoolDevice CPUDevice; 32 33 template <typename Device, typename T> 34 class NthElementOp : public OpKernel { 35 public: 36 explicit NthElementOp(OpKernelConstruction* context) : OpKernel(context) { 37 OP_REQUIRES_OK(context, context->GetAttr("reverse", &reverse_)); 38 } 39 40 void Compute(OpKernelContext* context) override { 41 // The second args is N, which must be a positive scalar. 42 const auto& n_in = context->input(1); 43 OP_REQUIRES(context, TensorShapeUtils::IsScalar(n_in.shape()), 44 errors::InvalidArgument("N must be scalar, got shape ", 45 n_in.shape().DebugString())); 46 int n = n_in.scalar<int32>()(); 47 OP_REQUIRES(context, n >= 0, 48 errors::InvalidArgument("Need n >= 0, got ", n)); 49 50 // The first args is input tensor, which must have 1 dimension at least. 51 const Tensor& input_in = context->input(0); 52 const int num_dims = input_in.dims(); 53 OP_REQUIRES(context, num_dims >= 1, 54 errors::InvalidArgument("Input must be >= 1-D, got shape ", 55 input_in.shape().DebugString())); 56 // The last dimension of input tensor must be greater than N. 57 OP_REQUIRES( 58 context, input_in.dim_size(num_dims - 1) > n, 59 errors::InvalidArgument("Input must have at least n+1 columns")); 60 61 // std::nth_element only support the nth-smallest selection. 62 if (reverse_) { 63 n = input_in.dim_size(num_dims - 1) - n - 1; 64 } 65 66 // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1]. 67 TensorShape out_shape; 68 for (int i = 0; i < num_dims - 1; ++i) { 69 out_shape.AddDim(input_in.dim_size(i)); 70 } 71 Tensor* output_tensor = nullptr; 72 OP_REQUIRES_OK(context, 73 context->allocate_output(0, out_shape, &output_tensor)); 74 75 functor::NthElementFunctor<Device, T> nthElementFunc; 76 nthElementFunc(context, input_in, *output_tensor, n, reverse_); 77 } 78 79 private: 80 bool reverse_; 81 }; 82 83 namespace functor { 84 85 template <typename T> 86 struct NthElementFunctor<CPUDevice, T> { 87 void operator()(OpKernelContext* context, const Tensor& input_tensor, 88 Tensor& output_tensor, int n, bool reverse) { 89 const T* input = input_tensor.flat<T>().data(); 90 T* output = output_tensor.flat<T>().data(); 91 92 // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1], 93 // then num_rows = d1*d2...dk-1, last_dim = dk. 94 const int num_rows = output_tensor.NumElements(); 95 const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1); 96 97 // Allocate each row to different shard. 98 auto SubNthElement = [&, input, output, last_dim, n](int start, int limit) { 99 // std::nth_element would rearrange the array, so we need a new buffer. 100 std::vector<T> buf(last_dim); 101 102 for (int b = start; b < limit; ++b) { 103 // Copy from one row of elements to buffer 104 const T* input_start = input + b * last_dim; 105 const T* input_end = input + (b + 1) * last_dim; 106 std::copy(input_start, input_end, buf.begin()); 107 108 std::nth_element(buf.begin(), buf.begin() + n, buf.end()); 109 // The element placed in the nth position is exactly the element that 110 // would occur in this position if the range was fully sorted. 111 output[b] = buf[n]; 112 } 113 }; 114 115 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); 116 // The average time complexity of partition-based nth_element (BFPRT) is 117 // O(n), althought the worst time complexity could be O(n^2). Here, 20 is a 118 // empirical factor of cost_per_unit. 119 Shard(worker_threads.num_threads, worker_threads.workers, num_rows, 120 20 * last_dim, SubNthElement); 121 } 122 }; 123 124 } // namespace functor 125 126 #define REGISTER_NTHOP(T) \ 127 REGISTER_KERNEL_BUILDER( \ 128 Name("NthElement").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 129 NthElementOp<CPUDevice, T>) 130 131 TF_CALL_REAL_NUMBER_TYPES(REGISTER_NTHOP); 132 #undef REGISTER_NTHOP 133 134 } // end namespace tensorflow 135