1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/math_ops.cc. 17 18 #include "tensorflow/core/kernels/bucketize_op.h" 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/framework/register_types.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor_shape.h" 23 #include "tensorflow/core/platform/logging.h" 24 #include "tensorflow/core/platform/types.h" 25 26 namespace tensorflow { 27 28 using CPUDevice = Eigen::ThreadPoolDevice; 29 using GPUDevice = Eigen::GpuDevice; 30 31 namespace functor { 32 33 template <typename T> 34 struct BucketizeFunctor<CPUDevice, T> { 35 // PRECONDITION: boundaries_vector must be sorted. 36 static Status Compute(OpKernelContext* context, 37 const typename TTypes<T, 1>::ConstTensor& input, 38 const std::vector<float>& boundaries_vector, 39 typename TTypes<int32, 1>::Tensor& output) { 40 const int N = input.size(); 41 for (int i = 0; i < N; i++) { 42 auto first_bigger_it = std::upper_bound( 43 boundaries_vector.begin(), boundaries_vector.end(), input(i)); 44 output(i) = first_bigger_it - boundaries_vector.begin(); 45 } 46 47 return Status::OK(); 48 } 49 }; 50 51 } // namespace functor 52 53 template <typename Device, typename T> 54 class BucketizeOp : public OpKernel { 55 public: 56 explicit BucketizeOp(OpKernelConstruction* context) : OpKernel(context) { 57 OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_)); 58 OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), 59 errors::InvalidArgument("Expected sorted boundaries")); 60 } 61 62 void Compute(OpKernelContext* context) override { 63 const Tensor& input_tensor = context->input(0); 64 const auto input = input_tensor.flat<T>(); 65 66 Tensor* output_tensor = nullptr; 67 OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 68 &output_tensor)); 69 auto output = output_tensor->template flat<int32>(); 70 OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute( 71 context, input, boundaries_, output)); 72 } 73 74 private: 75 std::vector<float> boundaries_; 76 }; 77 78 #define REGISTER_KERNEL(T) \ 79 REGISTER_KERNEL_BUILDER( \ 80 Name("Bucketize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 81 BucketizeOp<CPUDevice, T>); 82 83 REGISTER_KERNEL(int32); 84 REGISTER_KERNEL(int64); 85 REGISTER_KERNEL(float); 86 REGISTER_KERNEL(double); 87 #undef REGISTER_KERNEL 88 89 #if GOOGLE_CUDA 90 #define REGISTER_KERNEL(T) \ 91 REGISTER_KERNEL_BUILDER( \ 92 Name("Bucketize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 93 BucketizeOp<GPUDevice, T>); 94 95 REGISTER_KERNEL(int32); 96 REGISTER_KERNEL(int64); 97 REGISTER_KERNEL(float); 98 REGISTER_KERNEL(double); 99 #undef REGISTER_KERNEL 100 #endif // GOOGLE_CUDA 101 102 } // namespace tensorflow 103