/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/matrix_band_part_op.h"

#include <algorithm>
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class MatrixBandPartOp : public OpKernel {
 public:
  explicit MatrixBandPartOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
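    // MatrixBandPart keeps the in-band part of each innermost matrix and
    // zeroes everything else: element (row, col) survives iff
    //   (num_lower < 0 || row - col <= num_lower) &&
    //   (num_upper < 0 || col - row <= num_upper).
    // E.g. with num_lower = 1 and num_upper = 2, row 2 of a 4x4 matrix keeps
    // columns 1..3 (0-based) and the rest of that row becomes zero. A
    // negative num_lower (num_upper) keeps the whole lower (upper) triangle.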
    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input.shape().DebugString()));
    auto input_reshaped = input.flat_inner_dims<T, 3>();

    const Tensor& num_lower_in = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
                errors::InvalidArgument("num_lower must be scalar, got shape ",
                                        num_lower_in.shape().DebugString()));

    auto as_int64_scalar = [](const Tensor& tensor) -> int64 {
      if (tensor.dtype() == DT_INT32) {
        return tensor.scalar<int32>()();
      } else {
        return tensor.scalar<int64>()();
      }
    };
    const int64 num_lower = as_int64_scalar(num_lower_in);
    OP_REQUIRES(
        context, num_lower <= input_reshaped.dimension(1),
        errors::InvalidArgument(
            "num_lower must be negative or less than or equal to the number "
            "of rows (",
            input_reshaped.dimension(1), ") got: ", num_lower));

    const Tensor& num_upper_in = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
                errors::InvalidArgument("num_upper must be scalar, got shape ",
                                        num_upper_in.shape().DebugString()));
    const int64 num_upper = as_int64_scalar(num_upper_in);
    OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
                errors::InvalidArgument(
                    "num_upper must be negative or less than or equal to the "
                    "number of columns (",
                    input_reshaped.dimension(2), ") got: ", num_upper));

    if (input.NumElements() == 0 ||
        ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
         (num_upper < 0 || num_upper == input_reshaped.dimension(2)))) {
      // This is a no-op.
      context->set_output(0, input);
      return;
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input_shape, &output));
    auto output_reshaped = output->flat_inner_dims<T, 3>();
    functor::MatrixBandPartFunctor<Device, T> fn;
    fn(context, context->eigen_device<Device>(), num_lower, num_upper,
       input_reshaped, output_reshaped);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
};

#define REGISTER_MATRIX_BAND_PART(type)                                    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MatrixBandPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixBandPartOp<CPUDevice, type>);
TF_CALL_POD_TYPES(REGISTER_MATRIX_BAND_PART);
#undef REGISTER_MATRIX_BAND_PART

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_BAND_PART(type)             \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixBandPart")     \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          MatrixBandPartOp<CPUDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART);
#undef REGISTER_BATCH_MATRIX_BAND_PART

// Implementation of the functor specialization for CPU.
namespace functor {

// CPU implementation of MatrixBandPartFunctor.
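// The work is flattened to b * m matrix rows and sharded over the threadpool
// with an estimated cost of roughly 10 * n per row. When the op runs in
// place (the output aliases the input), only the out-of-band entries of each
// row are zeroed; otherwise each shard is zero-filled first and the in-band
// slice of every row is copied from the input.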
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Scalar>
struct MatrixBandPartFunctor<CPUDevice, Scalar> {
  void operator()(OpKernelContext* context, const CPUDevice& device,
                  int num_lower_diags, int num_upper_diags,
                  typename TTypes<Scalar, 3>::ConstTensor input,
                  typename TTypes<Scalar, 3>::Tensor output) {
    const int64 b = input.dimension(0);
    const int64 m = input.dimension(1);
    const int64 n = input.dimension(2);
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    const int64 total_rows = b * m;
    const int64 row_cost = 10 * n;
    const bool in_place = input.data() == output.data();
    auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
      if (!in_place) {
        std::fill(output.data() + begin * n, output.data() + end * n,
                  Scalar());
      }
      const int64 batch_begin = begin / m;
      const int64 batch_end = (end + m - 1) / m;
      for (int64 batch = batch_begin; batch < batch_end; ++batch) {
        const int64 row_begin = begin > batch * m ? begin % m : 0;
        const int64 row_end = end < (batch + 1) * m ? end % m : m;
        for (int64 row = row_begin; row < row_end; ++row) {
          const int64 band_start =
              num_lower_diags < 0
                  ? 0
                  : std::min(n, std::max(int64{0}, row - num_lower_diags));
          const int64 band_end =
              num_upper_diags < 0
                  ? n
                  : std::min(static_cast<int64>(n), row + num_upper_diags + 1);
          if (in_place) {
            if (band_start > 0) {
              std::fill(&output(batch, row, 0),
                        &output(batch, row, band_start), Scalar());
            }
            if (band_end < n) {
              std::fill(&output(batch, row, band_end), &output(batch, row, n),
                        Scalar());
            }
          } else {
            if (band_start < band_end) {
              const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
                                                                band_start);
              const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
                  1, 1, band_end - band_start);
              output.slice(indices, sizes) = input.slice(indices, sizes);
            }
          }
        }
      }
    };
    thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
  }
};

#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>;
TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
#undef DEFINE_CPU_SPEC

}  // namespace functor

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  struct MatrixBandPartFunctor<GPUDevice, T> {                         \
    void operator()(OpKernelContext* context, const GPUDevice& device, \
                    int num_lower_diags, int num_upper_diags,          \
                    typename TTypes<T, 3>::ConstTensor input,          \
                    typename TTypes<T, 3>::Tensor output);             \
  };                                                                   \
  extern template struct MatrixBandPartFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
}  // namespace functor
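// The GPU specializations declared above are defined in the corresponding
// *_gpu.cu.cc translation unit; the declarations here only make them visible
// to the registrations below. num_lower and num_upper are pinned to host
// memory because Compute() reads them as host-side scalars before invoking
// the functor.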
// Registration of the GPU implementations.
#define REGISTER_MATRIX_BAND_PART_GPU(type)              \
  REGISTER_KERNEL_BUILDER(Name("MatrixBandPart")         \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("num_lower")   \
                              .HostMemory("num_upper"),  \
                          MatrixBandPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_bool(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_complex64(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_complex128(REGISTER_MATRIX_BAND_PART_GPU);
#undef REGISTER_MATRIX_BAND_PART_GPU

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_BAND_PART_GPU(type)        \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixBandPart")    \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("num_lower")   \
                              .HostMemory("num_upper"),  \
                          MatrixBandPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART_GPU);
#undef REGISTER_BATCH_MATRIX_BAND_PART_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow