/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/diag_op.h"

#include <algorithm>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Generate the diagonal tensor with the diagonal set to the input tensor.
template <typename Device, typename T>
class DiagOp : public OpKernel {
 public:
  explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& diagonal = context->input(0);
    const int num_dims = diagonal.dims();
    OP_REQUIRES(
        context, 0 != num_dims,
        errors::InvalidArgument("Input must be at least rank 1, got 0"));
    // The output shape is the input shape repeated twice: an input of shape
    // [s1, ..., sk] yields an output of shape [s1, ..., sk, s1, ..., sk].
    TensorShape out_shape;
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, out_shape, &output_tensor));
    functor::DiagFunctor<Device, T> diagFunc;
    Status s =
        diagFunc(context, diagonal.NumElements(), diagonal.flat<T>().data(),
                 output_tensor->flat<T>().data());
    OP_REQUIRES_OK(context, s);
  }
};
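
// For example (matching the semantics above), a rank-1 input of shape [3]
// produces a [3, 3] output with the input along the main diagonal:
//
//   'diagonal' is [1, 2, 3]
//   diag(diagonal) ==> [[1, 0, 0],
//                       [0, 2, 0],
//                       [0, 0, 3]]
//
// while a rank-2 input of shape [2, 3] produces a rank-4 output of shape
// [2, 3, 2, 3].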

// Extract the diagonal from the input tensor: for an input of shape
// [s1, ..., sk, s1, ..., sk], produce the [s1, ..., sk] tensor holding the
// input's main diagonal.
template <typename Device, typename T>
class DiagPartOp : public OpKernel {
 public:
  explicit DiagPartOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor = context->input(0);
    const int num_dims = tensor.dims();
    const int out_dims = num_dims / 2;
    OP_REQUIRES(context, num_dims > 0 && 0 == num_dims % 2,
                errors::InvalidArgument("The rank of the tensor should be "
                                        "even and positive, got shape ",
                                        tensor.shape().DebugString()));
    for (int i = 0; i < out_dims; i++) {
      OP_REQUIRES(
          context, tensor.dim_size(i) == tensor.dim_size(i + out_dims),
          errors::InvalidArgument("Invalid shape ",
                                  tensor.shape().DebugString(), ": dimensions ",
                                  i, " and ", i + out_dims, " do not match."));
    }

    TensorShape out_shape;
    for (int i = 0; i < out_dims; ++i) {
      out_shape.AddDim(tensor.dim_size(i));
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
    functor::DiagPartFunctor<Device, T> diagPartFunc;
    Status s = diagPartFunc(context, out_shape.num_elements(),
                            tensor.flat<T>().data(), output->flat<T>().data());
    OP_REQUIRES_OK(context, s);
  }
};

// Implementation of the functor specialization for CPU.
//
// By the definition of the diagonal,
//   `output[i1, ..., ik, i1, ..., ik] = input[i1, ..., ik]`.
//
// Let the shape of the input be [s1, ..., sk]. Any offset into the input
// buffer corresponds to a coordinate [i1, ..., ik], where
//   `index = i1 * (s2 * ... * sk) + i2 * (s3 * ... * sk) + ... + ik`.
//
// Let `new_index` be the offset into the output buffer at coordinate
// [i1, ..., ik, i1, ..., ik]. Then
//   `new_index = i1 * (s2 * ... * sk * s1 * ... * sk)
//              + i2 * (s3 * ... * sk * s1 * ... * sk)
//              + ... + ik * (s1 * ... * sk)
//              + i1 * (s2 * ... * sk) + i2 * (s3 * ... * sk) + ... + ik
//            = (i1 * (s2 * ... * sk) + i2 * (s3 * ... * sk) + ... + ik)
//                * (1 + s1 * ... * sk)
//            = index * (1 + s1 * ... * sk)`.
//
// Letting `size = s1 * ... * sk`, we finally have
//   `new_index = index * (1 + size)`,
// which is the transfer function used below. This trick keeps the
// implementations simple and easy to parallelize.
namespace functor {
template <typename T>
struct DiagFunctor<CPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64 size, const T* in,
                                        T* out) {
    // Each shard is responsible for writing values in the index range
    // [start * size, limit * size).
    auto subDiag = [in, out, size](int64 start, int64 limit) {
      // Zero-fill this shard's slice of the output, then scatter the
      // diagonal entries into it.
      std::fill(out + size * start, out + size * limit, T());
      for (int64 index = start; index < limit; ++index) {
        out[(1 + size) * index] = in[index];
      }
    };

    // Here, 5 is an empirical multiplier: each work unit zero-fills a row of
    // `size` elements, so cost_per_unit is proportional to `size`.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, size, 5 * size,
          subDiag);
    return Status::OK();
  }
};
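
// As a concrete instance of the index mapping above: for a rank-1 input of
// shape [4], size = 4 and new_index = 5 * index, so the input values land at
// flat output offsets 0, 5, 10, and 15, i.e. on the main diagonal of a
// row-major 4 x 4 matrix. DiagPartFunctor below inverts the mapping, reading
// `out[index] = in[(1 + size) * index]`.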

template <typename T>
struct DiagPartFunctor<CPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64 size, const T* in,
                                        T* out) {
    // Each shard is responsible for extracting values in the index range
    // [start, limit).
    auto subDiagPart = [in, out, size](int64 start, int64 limit) {
      for (int64 index = start; index < limit; ++index) {
        out[index] = in[(1 + size) * index];
      }
    };

    // Here, 5 is an empirical cost_per_unit: each work unit copies a single
    // diagonal element.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, size, 5,
          subDiagPart);
    return Status::OK();
  }
};
}  // namespace functor

// Register the CPU kernels.
#define REGISTER_DIAGOP(T)                                    \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Diag").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DiagOp<CPUDevice, T>)

TF_CALL_double(REGISTER_DIAGOP);
TF_CALL_float(REGISTER_DIAGOP);
TF_CALL_int32(REGISTER_DIAGOP);
TF_CALL_int64(REGISTER_DIAGOP);
TF_CALL_complex64(REGISTER_DIAGOP);
TF_CALL_complex128(REGISTER_DIAGOP);
#undef REGISTER_DIAGOP

#define REGISTER_DIAGPARTOP(T)                                    \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("DiagPart").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DiagPartOp<CPUDevice, T>)

TF_CALL_double(REGISTER_DIAGPARTOP);
TF_CALL_float(REGISTER_DIAGPARTOP);
TF_CALL_int32(REGISTER_DIAGPARTOP);
TF_CALL_int64(REGISTER_DIAGPARTOP);
TF_CALL_complex64(REGISTER_DIAGPARTOP);
TF_CALL_complex128(REGISTER_DIAGPARTOP);
#undef REGISTER_DIAGPARTOP

// Register the GPU kernels.
#ifdef GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
extern template struct DiagFunctor<GPUDevice, double>;
extern template struct DiagFunctor<GPUDevice, float>;
extern template struct DiagFunctor<GPUDevice, int32>;
extern template struct DiagFunctor<GPUDevice, int64>;
extern template struct DiagFunctor<GPUDevice, complex64>;
extern template struct DiagFunctor<GPUDevice, complex128>;
}  // namespace functor

#define REGISTER_DIAGOP_GPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Diag").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DiagOp<GPUDevice, T>)

TF_CALL_double(REGISTER_DIAGOP_GPU);
TF_CALL_float(REGISTER_DIAGOP_GPU);
TF_CALL_int32(REGISTER_DIAGOP_GPU);
TF_CALL_int64(REGISTER_DIAGOP_GPU);
TF_CALL_complex64(REGISTER_DIAGOP_GPU);
TF_CALL_complex128(REGISTER_DIAGOP_GPU);
#undef REGISTER_DIAGOP_GPU

// Forward declarations of the functor specializations for GPU.
namespace functor {
extern template struct DiagPartFunctor<GPUDevice, double>;
extern template struct DiagPartFunctor<GPUDevice, float>;
extern template struct DiagPartFunctor<GPUDevice, int32>;
extern template struct DiagPartFunctor<GPUDevice, int64>;
extern template struct DiagPartFunctor<GPUDevice, complex64>;
extern template struct DiagPartFunctor<GPUDevice, complex128>;
}  // namespace functor

#define REGISTER_DIAGPARTOP_GPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("DiagPart").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DiagPartOp<GPUDevice, T>)

TF_CALL_double(REGISTER_DIAGPARTOP_GPU);
TF_CALL_float(REGISTER_DIAGPARTOP_GPU);
TF_CALL_int32(REGISTER_DIAGPARTOP_GPU);
TF_CALL_int64(REGISTER_DIAGPARTOP_GPU);
TF_CALL_complex64(REGISTER_DIAGPARTOP_GPU);
TF_CALL_complex128(REGISTER_DIAGPARTOP_GPU);
#undef REGISTER_DIAGPARTOP_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
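
// For reference, each TF_CALL_<type>(REGISTER_*) invocation above expands to
// a single kernel registration; e.g. TF_CALL_float(REGISTER_DIAGOP) is
// equivalent to:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Diag").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       DiagOp<CPUDevice, float>);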