1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_ 17 #define TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/tensor_types.h" 23 #include "tensorflow/core/framework/type_traits.h" 24 #include "tensorflow/core/kernels/bounds_check.h" 25 #include "tensorflow/core/platform/prefetch.h" 26 #include "tensorflow/core/platform/types.h" 27 #include "tensorflow/core/util/work_sharder.h" 28 29 namespace tensorflow { 30 typedef Eigen::ThreadPoolDevice CPUDevice; 31 32 namespace functor { 33 34 // Helper method to copy using memcpy. 
// Gathers slices params(b, indices(i), :) into out(b, i, :) for every batch b
// and every position i in `indices`, sharding the copy work across the CPU
// worker thread pool.  Simple (memcpy-able) element types are copied with
// memcpy; other element types (e.g. strings) fall back to an Eigen chip
// assignment.
//
// Args:
//   ctx: kernel context; used only to reach the CPU worker threads.
//   params: [batch_size, limit, slice_elems] source tensor.
//   indices: flat vector of row indices into dimension 1 of `params`.
//   slice_elems: runtime innermost-slice length.  Overridden by
//     `static_slice_elems` when the latter is >= 0, giving the compiler a
//     compile-time constant to optimize the memcpy size against.
//   out: [batch_size, indices_size, slice_elems] destination tensor.
//
// Returns -1 if every index was in range, otherwise the position within
// `indices` of (one of) the out-of-bounds indices observed by a shard.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopies(OpKernelContext* ctx,
                        typename TTypes<T, 3>::ConstTensor params,
                        typename TTypes<Index>::ConstFlat indices,
                        SliceIndex slice_elems,
                        typename TTypes<T, 3>::Tensor out) {
  const SliceIndex indices_size = static_cast<SliceIndex>(indices.dimension(0));
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const Index limit = static_cast<Index>(params.dimension(1));
  T* out_base = &out(0, 0, 0);
  const T* params_base = &params(0, 0, 0);
  if (static_slice_elems >= 0) {
    // Give compiler static knowledge of the number of elements/bytes
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Store the value of invalidate index for printing error information, it's a
  // shared variable.
  SliceIndex result = -1;
  // One unit of work is one (batch, index) pair; `start`/`end` address the
  // flattened, batch-major [batch_size * indices_size] iteration space.
  auto work = [&](int64 start, int64 end) {
    SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(start % indices_size);
    SliceIndex batch_idx_end = static_cast<SliceIndex>(end / indices_size);
    SliceIndex indices_idx_end = static_cast<SliceIndex>(end % indices_size);

    while ((batch_idx < batch_idx_end) ||
           (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex b_next = batch_idx + 1;
      // Look one iteration ahead and prefetch the source and destination
      // slices we will touch next: either the next index within the current
      // batch, or the first index of the next batch.  The branch taken also
      // fixes up (i_next, b_next) to be the coordinates of that iteration.
      if ((batch_idx == batch_idx_end && i_next < indices_idx_end) ||
          (i_next < indices_size)) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(batch_idx, indices(i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0));
        b_next = batch_idx;
      } else if (b_next <= batch_idx_end) {
        port::prefetch<port::PREFETCH_HINT_T0>(&params(b_next, indices(0), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0));
        i_next = 0;
      }
      // Force a copy of the index value so the bounds check below and the
      // uses after it cannot observe two different values if the indices
      // buffer is concurrently modified.
      const Index index = internal::SubtleMustCopy(indices(indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);  // `result` is shared by all shards.
        result = indices_idx;
        return;
      }
      // Copy using memcpy if possible, otherwise an Eigen loop
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
      // ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            out_base + (batch_idx * indices_size + indices_idx) * slice_elems,
            params_base + (batch_idx * static_cast<SliceIndex>(limit) +
                           static_cast<SliceIndex>(index)) *
                              slice_elems,
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<1>(indices_idx) = params.template chip<1>(index);
      }
      indices_idx = i_next;
      batch_idx = b_next;
    }
  };

  // Cost per unit of work is the slice byte count, so the sharder balances
  // shards by bytes copied rather than by index count.
  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}

// CPU dispatcher: picks template parameters for HandleCopies that give the
// compiler the most static knowledge — 32-bit slice indexing whenever every
// extent fits in int32 (cheaper address arithmetic), and a compile-time slice
// size for the common cases 10 and 20.
template <typename T, typename Index>
struct GatherFunctorCPU {
  // Returns -1 on success, or the position in `indices` of the first bad
  // index reported by HandleCopies.
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out) {
    const int64 N = indices.size();
    const int64 slice_size = out.dimension(2);
    int64 bad_i;

    // Use 64-bit indexing only if some extent could overflow int32.
    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      N > std::numeric_limits<int32>::max());
#define CALL(elems)                                                      \
  do {                                                                   \
    if (use_large) {                                                     \
      bad_i = HandleCopies<T, Index, int64, elems>(ctx, params, indices, \
                                                   slice_size, out);     \
    } else {                                                             \
      const int32 small_slice = static_cast<int32>(slice_size);          \
      bad_i = HandleCopies<T, Index, int32, elems>(ctx, params, indices, \
                                                   small_slice, out);    \
    }                                                                    \
  } while (0)

    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);  // -1 means the slice size is known only at run time.
#undef CALL

    return bad_i;
  }
};

// Device-generic gather functor; defined per device via specialization.
template <typename Device, typename T, typename Index>
struct GatherFunctor {
  // Contract (see the CPU specialization): returns -1 on success, or the
  // position in `indices` of the first out-of-range index.
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out);
};

// CPU specialization: forwards to the sharded memcpy implementation above.
template <typename T, typename Index>
struct GatherFunctor<CPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out) {
    return GatherFunctorCPU<T, Index>()(ctx, params, indices, out);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_