1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ 17 #define TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ 18 19 // Specialization of GatherNdSlice to CPU 20 21 #define EIGEN_USE_THREADS 22 23 #include <atomic> 24 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/kernels/bounds_check.h" 29 #include "tensorflow/core/kernels/gather_nd_op.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/mem.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/core/util/util.h" 34 35 namespace tensorflow { 36 37 typedef Eigen::ThreadPoolDevice CPUDevice; 38 39 namespace generator { 40 41 template <typename T, typename Index, int IXDIM> 42 class GatherNdSliceGenerator { 43 public: 44 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator( 45 const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices, 46 typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, 47 typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc) 48 : slice_size_(slice_size), 49 Tindices_(Tindices), 50 Tparams_(Tparams), 51 Tout_(Tout), 52 error_loc_(error_loc) {} 53 54 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices( 55 const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const { 56 (*ix)[IXDIM] = 0; 57 bool out_of_bounds = false; 58 for (int i = 0; i < IXDIM; ++i) { 59 const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); 60 (*ix)[i] = ix_i; 61 out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); 62 } 63 return out_of_bounds; 64 } 65 66 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 67 operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const { 68 const Index loc = loc_array[0]; 69 Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix; 70 Eigen::array<Eigen::DenseIndex, 2> ix_out; 71 ix_out[0] = loc; 72 ix_out[1] = 0; 73 const bool out_of_bounds = GenerateIndices(loc, &ix); 74 if (TF_PREDICT_FALSE(out_of_bounds)) { 75 error_loc_->store(loc); 76 std::fill_n(&Tout_(ix_out), slice_size_, T()); 77 } else { 78 std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out)); 79 } 80 81 return static_cast<int32>(0); // Return something... 82 } 83 84 private: 85 const Index slice_size_; 86 const typename TTypes<Index>::ConstMatrix Tindices_; 87 const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_; 88 mutable typename TTypes<T>::Matrix Tout_; 89 std::atomic<Index>* error_loc_; 90 }; 91 92 } // namespace generator 93 94 namespace functor { 95 96 template <typename T, typename Index, int IXDIM> 97 struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { 98 Index operator()(const CPUDevice& d, const Index slice_size, 99 typename TTypes<int32>::Scalar Tscratch, 100 typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, 101 typename TTypes<Index>::ConstMatrix Tindices, 102 typename TTypes<T>::Matrix Tout) { 103 std::atomic<Index> error_loc(-1); 104 105 const Eigen::DenseIndex batch_size = Tindices.dimension(0); 106 #if !defined(EIGEN_HAS_INDEX_LIST) 107 Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }}; 108 Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }}; 109 #else 110 Eigen::IndexList<Eigen::type2index<1> > reshape_dims; 111 Eigen::IndexList<Eigen::DenseIndex> broadcast_dims; 112 broadcast_dims.set(0, batch_size); 113 #endif 114 generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator( 115 slice_size, Tindices, Tparams, Tout, &error_loc); 116 Tscratch.device(d) = Tscratch.reshape(reshape_dims) 117 .broadcast(broadcast_dims) 118 .generate(gather_nd_generator) 119 .sum(); 120 121 // error_loc() returns -1 if there's no out-of-bounds index, 122 // otherwise it returns the location of an OOB index in Tindices. 123 return error_loc.load(); 124 } 125 }; 126 127 #define REGISTER_GATHER_ND_FULL(T, Index) \ 128 template Index GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>:: \ 129 operator()(const CPUDevice& d, const Index slice_size, \ 130 typename TTypes<int32>::Scalar Tscratch, \ 131 typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor Tparams, \ 132 typename TTypes<Index>::ConstMatrix Tindices, \ 133 typename TTypes<T>::Matrix Tout); 134 135 #define REGISTER_GATHER_ND_CPU(type) \ 136 REGISTER_GATHER_ND_FULL(type, int32); \ 137 REGISTER_GATHER_ND_FULL(type, int64) 138 139 TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); 140 141 } // namespace functor 142 143 } // namespace tensorflow 144 145 #endif // TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ 146