1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/contrib/rnn/kernels/gru_ops.h" 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 #include "tensorflow/core/framework/op_kernel.h" 21 22 namespace tensorflow { 23 24 typedef Eigen::ThreadPoolDevice CPUDevice; 25 typedef Eigen::GpuDevice GPUDevice; 26 27 template <typename Device, typename T, bool USE_CUBLAS> 28 class GRUCellBlockOp : public OpKernel { 29 public: 30 explicit GRUCellBlockOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 31 // TODO(gitegaurav) Replace the input checks with some smarter function. 32 void Compute(OpKernelContext* ctx) override { 33 // Grab the input tensors. 34 const Tensor* x_tensor = nullptr; 35 OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor)); 36 37 const Tensor* h_prev_tensor = nullptr; 38 OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor)); 39 40 const Tensor* w_ru_tensor = nullptr; 41 OP_REQUIRES_OK(ctx, ctx->input("w_ru", &w_ru_tensor)); 42 43 const Tensor* w_c_tensor = nullptr; 44 OP_REQUIRES_OK(ctx, ctx->input("w_c", &w_c_tensor)); 45 46 const Tensor* b_ru_tensor = nullptr; 47 OP_REQUIRES_OK(ctx, ctx->input("b_ru", &b_ru_tensor)); 48 49 const Tensor* b_c_tensor = nullptr; 50 OP_REQUIRES_OK(ctx, ctx->input("b_c", &b_c_tensor)); 51 52 const int64 batch_size = x_tensor->dim_size(0); 53 const int64 input_size = x_tensor->dim_size(1); 54 const int64 cell_size = h_prev_tensor->dim_size(1); 55 56 // Sanity checks for input shapes. 57 58 // Shape of 'h' must be [batch_size, cell_size] 59 OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size, 60 errors::InvalidArgument("h_prev.dims(0) != batch_size: ", 61 h_prev_tensor->dim_size(0), " vs. ", 62 batch_size)); 63 OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, 64 errors::InvalidArgument( 65 "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), 66 " vs. ", cell_size)); 67 68 // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] 69 OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, 70 errors::InvalidArgument( 71 "w_ru.dim_size(0) != input_size + cell_size: ", 72 w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size)); 73 74 OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2, 75 errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ", 76 w_ru_tensor->dim_size(1), " vs. ", 77 cell_size * 2)); 78 79 // Shape of 'w_c' must be [input_size+cell_size, cell_size] 80 OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size, 81 errors::InvalidArgument( 82 "w_c.dim_size(0) != input_size + cell_size: ", 83 w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); 84 85 OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, 86 errors::InvalidArgument( 87 "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), 88 " vs. ", cell_size)); 89 90 // Shape of 'b_ru' must be [2*cell_size] 91 OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, 92 errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ", 93 b_ru_tensor->dim_size(0), " vs. ", 94 cell_size * 2)); 95 96 OP_REQUIRES(ctx, b_ru_tensor->dims() == 1, 97 errors::InvalidArgument("Rank of b_ru must be 1", 98 b_ru_tensor->dims(), " vs. 1", 1)); 99 // Shape of 'b_c' must be [cell_size] 100 OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, 101 errors::InvalidArgument( 102 "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), 103 " vs. ", cell_size)); 104 OP_REQUIRES(ctx, b_c_tensor->dims() == 1, 105 errors::InvalidArgument("Rank of b_c must be 1", 106 b_c_tensor->dims(), " vs. 1")); 107 108 // Create output tensors. 109 Tensor* r_tensor = nullptr; 110 OP_REQUIRES_OK( 111 ctx, ctx->allocate_output("r", TensorShape({batch_size, cell_size}), 112 &r_tensor)); 113 114 Tensor* u_tensor = nullptr; 115 OP_REQUIRES_OK( 116 ctx, ctx->allocate_output("u", TensorShape({batch_size, cell_size}), 117 &u_tensor)); 118 119 Tensor* c_tensor = nullptr; 120 OP_REQUIRES_OK( 121 ctx, ctx->allocate_output("c", TensorShape({batch_size, cell_size}), 122 &c_tensor)); 123 124 Tensor* h_tensor = nullptr; 125 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( 126 {"h_prev"}, "h", 127 TensorShape({batch_size, cell_size}), &h_tensor)); 128 129 // Allocate temp tensors. 130 Tensor x_h_prev_tensor; 131 OP_REQUIRES_OK(ctx, ctx->allocate_temp( 132 DataTypeToEnum<T>::v(), 133 TensorShape({batch_size, input_size + cell_size}), 134 &x_h_prev_tensor)); 135 136 Tensor x_h_prevr_tensor; 137 OP_REQUIRES_OK(ctx, ctx->allocate_temp( 138 DataTypeToEnum<T>::v(), 139 TensorShape({batch_size, input_size + cell_size}), 140 &x_h_prevr_tensor)); 141 142 Tensor r_u_bar_tensor; 143 OP_REQUIRES_OK(ctx, 144 ctx->allocate_temp(DataTypeToEnum<T>::v(), 145 TensorShape({batch_size, 2 * cell_size}), 146 &r_u_bar_tensor)); 147 148 const Device& device = ctx->eigen_device<Device>(); 149 150 functor::GRUBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size, 151 cell_size)( 152 ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(), 153 w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(), 154 b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_u_bar_tensor.matrix<T>(), 155 r_tensor->matrix<T>(), u_tensor->matrix<T>(), c_tensor->matrix<T>(), 156 h_tensor->matrix<T>(), x_h_prev_tensor.matrix<T>(), 157 x_h_prevr_tensor.matrix<T>()); 158 } 159 }; 160 161 // Register the Block GRU cell kernel for CPU. 162 #define REGISTER_KERNEL(T) \ 163 REGISTER_KERNEL_BUILDER( \ 164 Name("GRUBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 165 GRUCellBlockOp<CPUDevice, T, false>); 166 167 REGISTER_KERNEL(float); 168 #undef REGISTER_KERNEL 169 170 template <typename Device, typename T, bool USE_CUBLAS> 171 class GRUBlockCellGradOp : public OpKernel { 172 public: 173 explicit GRUBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 174 175 void Compute(OpKernelContext* ctx) override { 176 // Grab the input tensors. 177 const Tensor* x_tensor = nullptr; 178 OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor)); 179 180 const Tensor* h_prev_tensor = nullptr; 181 OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor)); 182 183 const Tensor* w_ru_tensor = nullptr; 184 OP_REQUIRES_OK(ctx, ctx->input("w_ru", &w_ru_tensor)); 185 186 const Tensor* w_c_tensor = nullptr; 187 OP_REQUIRES_OK(ctx, ctx->input("w_c", &w_c_tensor)); 188 189 const Tensor* b_ru_tensor = nullptr; 190 OP_REQUIRES_OK(ctx, ctx->input("b_ru", &b_ru_tensor)); 191 192 const Tensor* b_c_tensor = nullptr; 193 OP_REQUIRES_OK(ctx, ctx->input("b_c", &b_c_tensor)); 194 195 const Tensor* r_tensor = nullptr; 196 OP_REQUIRES_OK(ctx, ctx->input("r", &r_tensor)); 197 198 const Tensor* u_tensor = nullptr; 199 OP_REQUIRES_OK(ctx, ctx->input("u", &u_tensor)); 200 201 const Tensor* c_tensor = nullptr; 202 OP_REQUIRES_OK(ctx, ctx->input("c", &c_tensor)); 203 204 const Tensor* d_h_tensor = nullptr; 205 OP_REQUIRES_OK(ctx, ctx->input("d_h", &d_h_tensor)); 206 207 const int64 batch_size = x_tensor->dim_size(0); 208 const int64 input_size = x_tensor->dim_size(1); 209 const int64 cell_size = h_prev_tensor->dim_size(1); 210 211 // Sanity checks for input shapes. 212 213 // Shape of 'h_prev' must be [batch_size, cell_size] 214 OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size, 215 errors::InvalidArgument("h_prev.dims(0) != batch_size: ", 216 h_prev_tensor->dim_size(0), " vs. ", 217 batch_size)); 218 OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, 219 errors::InvalidArgument( 220 "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), 221 " vs. ", cell_size)); 222 223 // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] 224 OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, 225 errors::InvalidArgument( 226 "w_ru.dim_size(0) != input_size + cell_size: ", 227 w_ru_tensor->dim_size(0), " vs. ", input_size + cell_size)); 228 229 OP_REQUIRES(ctx, w_ru_tensor->dim_size(1) == cell_size * 2, 230 errors::InvalidArgument("w_ru.dim_size(1) != cell_size * 2: ", 231 w_ru_tensor->dim_size(1), " vs. ", 232 cell_size * 2)); 233 234 // Shape of 'w_c' must be [input_size+cell_size, cell_size] 235 OP_REQUIRES(ctx, w_c_tensor->dim_size(0) == input_size + cell_size, 236 errors::InvalidArgument( 237 "w_c.dim_size(0) != input_size + cell_size: ", 238 w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); 239 240 OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, 241 errors::InvalidArgument( 242 "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), 243 " vs. ", cell_size)); 244 245 // Shape of 'b_ru' must be [2*cell_size] 246 OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, 247 errors::InvalidArgument("b_ru.dim_size(0) != cell_size * 2: ", 248 b_ru_tensor->dim_size(0), " vs. ", 249 cell_size * 2)); 250 251 OP_REQUIRES(ctx, b_ru_tensor->dims() == 1, 252 errors::InvalidArgument("Rank of b_ru must be 1", 253 b_ru_tensor->dims(), " vs. 1")); 254 255 // Shape of 'b_c' must be [cell_size] 256 OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, 257 errors::InvalidArgument( 258 "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), 259 " vs. ", cell_size)); 260 261 OP_REQUIRES(ctx, b_c_tensor->dims() == 1, 262 errors::InvalidArgument("Rank of b_c must be 1 ", 263 b_c_tensor->dims(), " vs. 1")); 264 265 // Shape of 'r' must be [batch_size, cell_size] 266 OP_REQUIRES(ctx, r_tensor->dim_size(0) == batch_size, 267 errors::InvalidArgument( 268 "r.dims(0) != batch_size: ", r_tensor->dim_size(0), " vs. ", 269 batch_size)); 270 OP_REQUIRES(ctx, r_tensor->dim_size(1) == cell_size, 271 errors::InvalidArgument( 272 "r.dims(1) != cell_size: ", r_tensor->dim_size(1), " vs. ", 273 cell_size)); 274 275 // Shape of 'u' must be [batch_size, cell_size] 276 OP_REQUIRES(ctx, u_tensor->dim_size(0) == batch_size, 277 errors::InvalidArgument( 278 "u.dims(0) != batch_size: ", u_tensor->dim_size(0), " vs. ", 279 batch_size)); 280 OP_REQUIRES(ctx, u_tensor->dim_size(1) == cell_size, 281 errors::InvalidArgument( 282 "u.dims(1) != cell_size: ", u_tensor->dim_size(1), " vs. ", 283 cell_size)); 284 285 // Shape of 'c' must be [batch_size, cell_size] 286 OP_REQUIRES(ctx, c_tensor->dim_size(0) == batch_size, 287 errors::InvalidArgument( 288 "c.dims(0) != batch_size: ", c_tensor->dim_size(0), " vs. ", 289 batch_size)); 290 OP_REQUIRES(ctx, c_tensor->dim_size(1) == cell_size, 291 errors::InvalidArgument( 292 "c.dims(1) != cell_size: ", c_tensor->dim_size(1), " vs. ", 293 cell_size)); 294 295 // Shape of 'd_h' must be [batch_size, cell_size] 296 OP_REQUIRES(ctx, d_h_tensor->dim_size(0) == batch_size, 297 errors::InvalidArgument( 298 "d_h.dims(0) != batch_size: ", d_h_tensor->dim_size(0), 299 " vs. ", batch_size)); 300 OP_REQUIRES(ctx, d_h_tensor->dim_size(1) == cell_size, 301 errors::InvalidArgument( 302 "d_h.dims(1) != cell_size: ", d_h_tensor->dim_size(1), 303 " vs. ", cell_size)); 304 305 // Create output tensors. 306 Tensor* d_x_tensor = nullptr; 307 OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( 308 {"x"}, "d_x", TensorShape({batch_size, input_size}), 309 &d_x_tensor)); 310 311 Tensor* d_h_prev_tensor = nullptr; 312 OP_REQUIRES_OK( 313 ctx, ctx->forward_input_or_allocate_output( 314 {"h_prev"}, "d_h_prev", TensorShape({batch_size, cell_size}), 315 &d_h_prev_tensor)); 316 317 Tensor* d_c_bar_tensor; 318 OP_REQUIRES_OK(ctx, ctx->allocate_output( 319 "d_c_bar", TensorShape({batch_size, cell_size}), 320 &d_c_bar_tensor)); 321 322 Tensor* d_r_bar_u_bar_tensor; 323 OP_REQUIRES_OK( 324 ctx, ctx->allocate_output("d_r_bar_u_bar", 325 TensorShape({batch_size, 2 * cell_size}), 326 &d_r_bar_u_bar_tensor)); 327 328 // Allocate temp tensors. 329 Tensor d_r_bar_tensor; 330 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), 331 TensorShape({batch_size, cell_size}), 332 &d_r_bar_tensor)); 333 334 Tensor d_u_bar_tensor; 335 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), 336 TensorShape({batch_size, cell_size}), 337 &d_u_bar_tensor)); 338 339 Tensor d_h_prevr_tensor; 340 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(), 341 TensorShape({batch_size, cell_size}), 342 &d_h_prevr_tensor)); 343 344 Tensor d_x_component_1_h_prev_compenent_1; 345 OP_REQUIRES_OK(ctx, ctx->allocate_temp( 346 DataTypeToEnum<T>::v(), 347 TensorShape({batch_size, input_size + cell_size}), 348 &d_x_component_1_h_prev_compenent_1)); 349 350 Tensor d_x_component_2_h_prevr; 351 OP_REQUIRES_OK(ctx, ctx->allocate_temp( 352 DataTypeToEnum<T>::v(), 353 TensorShape({batch_size, input_size + cell_size}), 354 &d_x_component_2_h_prevr)); 355 356 const Device& device = ctx->eigen_device<Device>(); 357 358 functor::GRUBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size, 359 cell_size)( 360 ctx, device, x_tensor->matrix<T>(), h_prev_tensor->matrix<T>(), 361 w_ru_tensor->matrix<T>(), w_c_tensor->matrix<T>(), 362 b_ru_tensor->vec<T>(), b_c_tensor->vec<T>(), r_tensor->matrix<T>(), 363 u_tensor->matrix<T>(), c_tensor->matrix<T>(), d_h_tensor->matrix<T>(), 364 d_x_tensor->matrix<T>(), d_h_prev_tensor->matrix<T>(), 365 d_c_bar_tensor->matrix<T>(), d_r_bar_u_bar_tensor->matrix<T>(), 366 d_r_bar_tensor.matrix<T>(), d_u_bar_tensor.matrix<T>(), 367 d_h_prevr_tensor.matrix<T>(), 368 d_x_component_1_h_prev_compenent_1.matrix<T>(), 369 d_x_component_2_h_prevr.matrix<T>()); 370 } 371 }; 372 373 // Register the gradient kernel for CPU. 374 #define REGISTER_KERNEL(T) \ 375 REGISTER_KERNEL_BUILDER( \ 376 Name("GRUBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 377 GRUBlockCellGradOp<CPUDevice, T, false>); 378 379 REGISTER_KERNEL(float); 380 #undef REGISTER_KERNEL 381 382 // GPU support. 383 #if GOOGLE_CUDA 384 #define EIGEN_USE_GPU 385 386 // Forward declare the GPU Fprop functor. 387 namespace functor { 388 #define DECLARE_GPU_SPEC(T) \ 389 template <> \ 390 void GRUBlockCellFprop<GPUDevice, T, true>::operator()( \ 391 OpKernelContext* ctx, const GPUDevice& d, \ 392 typename TTypes<T>::ConstMatrix x, \ 393 typename TTypes<T>::ConstMatrix h_prev, \ 394 typename TTypes<T>::ConstMatrix w_ru, \ 395 typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \ 396 typename TTypes<T>::ConstVec b_c, typename TTypes<T>::Matrix r_u_bar, \ 397 typename TTypes<T>::Matrix r, typename TTypes<T>::Matrix u, \ 398 typename TTypes<T>::Matrix c, typename TTypes<T>::Matrix h, \ 399 typename TTypes<T>::Matrix x_h_prev, \ 400 typename TTypes<T>::Matrix x_h_prevr); \ 401 extern template struct GRUBlockCellFprop<GPUDevice, T, true>; 402 403 DECLARE_GPU_SPEC(float); 404 #undef DECLARE_GPU_SPEC 405 } // end namespace functor 406 407 // Register the Block GRU cell kernel for GPU. 408 #define REGISTER_GPU_KERNEL(T) \ 409 REGISTER_KERNEL_BUILDER( \ 410 Name("GRUBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 411 GRUCellBlockOp<GPUDevice, T, true>); 412 413 REGISTER_GPU_KERNEL(float); 414 #undef REGISTER_GPU_KERNEL 415 416 // Forward declare the GPU Bprop functor. 417 namespace functor { 418 #define DECLARE_GPU_SPEC(T) \ 419 template <> \ 420 void GRUBlockCellBprop<GPUDevice, T, true>::operator()( \ 421 OpKernelContext* ctx, const GPUDevice& d, \ 422 typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix h, \ 423 typename TTypes<T>::ConstMatrix w_ru, \ 424 typename TTypes<T>::ConstMatrix w_c, typename TTypes<T>::ConstVec b_ru, \ 425 typename TTypes<T>::ConstVec b_c, typename TTypes<T>::ConstMatrix r, \ 426 typename TTypes<T>::ConstMatrix u, typename TTypes<T>::ConstMatrix c, \ 427 typename TTypes<T>::ConstMatrix d_h, typename TTypes<T>::Matrix d_x, \ 428 typename TTypes<T>::Matrix d_h_prev, typename TTypes<T>::Matrix d_c_bar, \ 429 typename TTypes<T>::Matrix d_r_bar_u_bar, \ 430 typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar, \ 431 typename TTypes<T>::Matrix d_h_prevr, \ 432 typename TTypes<T>::Matrix d_x_comp1_h_prev_comp1, \ 433 typename TTypes<T>::Matrix d_x_comp2_and_h_prevr); \ 434 extern template struct GRUBlockCellBprop<GPUDevice, T, true>; 435 436 DECLARE_GPU_SPEC(float); 437 #undef DECLARE_GPU_SPEC 438 } // end namespace functor 439 440 // Register the gradient kernel for GPU. 441 #define REGISTER_GPU_KERNEL(T) \ 442 REGISTER_KERNEL_BUILDER( \ 443 Name("GRUBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 444 GRUBlockCellGradOp<GPUDevice, T, true>); 445 446 REGISTER_GPU_KERNEL(float); 447 #undef REGISTER_GPU_KERNEL 448 #endif // GOOGLE_CUDA 449 450 } // end namespace tensorflow 451