/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T>
void LSTMBlockCellFpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
    const T forget_bias, const T cell_clip, bool use_peephole,
    typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
    typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
    typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
    typename TTypes<T>::Matrix h) {
  // Concat xh = [x, h].
  xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
  xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev;

  // states1 = xh * w + b
  typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
  TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
      ctx, d, false, false, T(1), const_xh, w, T(0), icfo);
  Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
  Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
  icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);

  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});

  // Input gate.
  if (use_peephole) {
    auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
    i.device(d) =
        (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
            .sigmoid();
  } else {
    i.device(d) =
        icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
  }

  // Cell input.
  ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();

  // Forget gate (w/ bias).
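  // A positive forget_bias (commonly 1.0) is added to the forget gate's
  // pre-activation below, biasing the cell towards remembering at the start
  // of training.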
  if (use_peephole) {
    auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(forget_bias) + f_peep)
                      .sigmoid();
  } else {
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(forget_bias))
                      .sigmoid();
  }

  // cs = ci .* i + f .* cs_prev
  cs.device(d) = i * ci + f * cs_prev;

  if (cell_clip > 0.0f) {
    cs.device(d) =
        cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op<T>());
  }

  // co = tanh(cs)
  co.device(d) = cs.tanh();

  // Output gate.
  if (use_peephole) {
    auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    o.device(d) =
        (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
            .sigmoid();
  } else {
    o.device(d) =
        icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
  }

  // h = o .* co
  h.device(d) = o * co;
}

template <typename Device, typename T, bool USE_CUBLAS>
void LSTMBlockCellBpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
    bool use_peephole, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    typename TTypes<T>::ConstMatrix cs_grad,
    typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    typename TTypes<T>::Vec wco_grad) {
  // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
  do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

  // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
  dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
  if (use_peephole) {
    dcs.device(d) =
        dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
  }

  // dci[t] = tanh'(ci[t]) dcs[t] i[t]
  dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

  // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
  df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

  // di[t] = sigm'(i[t]) dcs[t] ci[t]
  di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

  dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
  dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
  dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
  dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;

  cs_prev_grad.device(d) = dcs * f;
  if (use_peephole) {
    cs_prev_grad.device(d) =
        cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
        df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
  }
}

#define DEFINE_CPU_SPECS(T)                                                    \
  template <>                                                                  \
  void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(   \
      OpKernelContext* ctx, const CPUDevice& d, const T forget_bias,           \
      const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) {         \
    LSTMBlockCellFpropWithEigen<T>(                                            \
        *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev,       \
        h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h);        \
  }                                                                            \
  template <>                                                                  \
  void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()(   \
      OpKernelContext* ctx, const CPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,       \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,   \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,   \
      typename TTypes<T>::ConstMatrix co,                                      \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,      \
      typename TTypes<T>::Vec wco_grad) {                                      \
    LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>(         \
        *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b,  \
        i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo,    \
        cs_prev_grad, wci_grad, wcf_grad, wco_grad);                           \
  }                                                                            \
  template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>;    \
  template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;

DEFINE_CPU_SPECS(float);
#undef DEFINE_CPU_SPECS

}  // namespace functor

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellOp : public OpKernel {
 public:
  explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    // Allocate our output tensors.
    Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"h_prev"}, "i",
                            TensorShape({batch_size, cell_size}), &i_tensor));

    Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
                                  &cs_tensor));

    Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
                                  &f_tensor));

    Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"cs_prev"}, "o",
                            TensorShape({batch_size, cell_size}), &o_tensor));

    Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
                                  &ci_tensor));

    Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
                                  &co_tensor));

    Tensor* h_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
                                  &h_tensor));

    // Allocate our temp tensors.
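    // xh holds the concatenated [x, h_prev] block of shape
    // [batch_size, input_size + cell_size]; icfo holds the pre-activation
    // values of the i, ci, f and o gates, of shape [batch_size, cell_size * 4].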
    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, forget_bias_, cell_clip_, use_peephole_,
        x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
        h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
        wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
        xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
        f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
        co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
  }

 private:
  float forget_bias_;
  float cell_clip_;
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                              \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"),  \
      LSTMBlockCellOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
// REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void LSTMBlockCellFprop<GPUDevice, T, true>::operator()(                     \
      OpKernelContext* ctx, const GPUDevice& d, const T forget_bias,           \
      const T cell_clip, bool use_peephole, typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,             \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,              \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,            \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h);          \
                                                                               \
  extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      LSTMBlockCellOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellGradOp : public OpKernel {
 public:
  explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));

    const Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));

    const Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));

    const Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));

    const Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));

    const Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));

    const Tensor* cs_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));

    const Tensor* h_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "f.dim_size(1) != cell_size: ", f_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs_grad_tensor.dims(0) != batch_size: ",
                    cs_grad_tensor->dim_size(0), " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
                                        cs_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
                                        h_grad_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
                                        h_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    // Allocate our output tensors.
    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"cs_grad"}, "cs_prev_grad",
                 TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));

    Tensor* dicfo_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            "dicfo", TensorShape({batch_size, cell_size * 4}),
                            &dicfo_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wci"}, "wci_grad", wci_tensor->shape(), &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wcf"}, "wcf_grad", wcf_tensor->shape(), &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wco"}, "wco_grad", wco_tensor->shape(), &wco_grad_tensor));

    // Allocate our temp tensors.
    // do_, dcs, dci, df and di are per-gate gradient scratch buffers, each of
    // shape [batch_size, cell_size].
    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &di_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());

    functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, use_peephole_, x_tensor->matrix<T>(),
        cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
        wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
        cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
        ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
        cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
        do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
        df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
        cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
        wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
  }

 protected:
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"),  \
      LSTMBlockCellGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
// REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void LSTMBlockCellBprop<GPUDevice, T, true>::operator()(                     \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,       \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,   \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,   \
      typename TTypes<T>::ConstMatrix co,                                      \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,      \
      typename TTypes<T>::Vec wco_grad);                                       \
                                                                               \
  extern template struct LSTMBlockCellBprop<GPUDevice, T,                      \
                                            true /* USE_CUBLAS */>;

DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU_KERNEL(T)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      LSTMBlockCellGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

namespace {

// This helper class can be used to access timeslices of a 3D tensor. If a
// slice happens to be unaligned (usually because both batch size and number of
// cells are odd - this isn't common) this involves overhead, since data needs
// to be copied. However, if all slices are aligned, the bits aren't copied. In
// the cases where copying is needed, the outputs have to be recopied back.
// At the end of each time step you should call FinishTimeStep, which does this
// and also allows for reuse of temporary tensors.
template <typename Device, typename T>
class SliceHelper {
 public:
  explicit SliceHelper(OpKernelContext* ctx)
      : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}

  ~SliceHelper() {
    CHECK(copy_out_.empty());
    for (const auto& entry : pool_) {
      CHECK(!entry.second.second);  // nothing is in use
    }
  }

  // Slice through an input tensor. This may copy unaligned slices, but no
  // copying back will be done at the end.
  const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
    Tensor res = UnalignedSlice(t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      return AlignTensor(res, name);
    }
  }

  // Slice through an output tensor. This may copy unaligned slices, and
  // schedule copying back on destruction.
  Tensor OutputSlice(Tensor* t, int pos, const string& name) {
    Tensor res = UnalignedSlice(*t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      Tensor aligned = AlignTensor(res, name);
      copy_out_.emplace_back(res, aligned);
      return aligned;
    }
  }

  void FinishTimeStep() {
    for (const auto& p : copy_out_) {
      const Tensor& aligned = p.second;
      Tensor original = p.first;
      // Copy from aligned back to original.
      functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
                                                  original.unaligned_flat<T>());
    }
    copy_out_.clear();
    // Mark all entries as not in use.
    for (auto& entry : pool_) {
      entry.second.second = false;
    }
  }

 private:
  // Return a slice at position 'pos'. Result may be unaligned. The resulting
  // tensor always shares data with the source tensor.
  Tensor UnalignedSlice(const Tensor& t, int pos) const {
    Tensor res;
    // CHECK should never fail here, since the number of elements must match.
    CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
    return res;
  }

  // Assumes input is not aligned, creates a temporary aligned tensor of the
  // same shape and copies the original tensor's content into it.
  Tensor AlignTensor(const Tensor& t, const string& name) {
    VLOG(1) << "AlignTensor called for " << name << ", shape "
            << t.shape().DebugString()
            << ". This is unnecessary copying. Consider using shapes with even "
            << "sizes";
    Tensor aligned;
    auto found = pool_.find(name);
    if (found != pool_.end()) {  // found in pool
      CHECK(!found->second.second) << "Tensor " << name << " is in use";
      found->second.second = true;  // mark in use
      aligned = found->second.first;
      CHECK(aligned.shape().IsSameSize(t.shape()));
      CHECK_EQ(aligned.dtype(), t.dtype());
    } else {  // allocate a new temporary tensor
      TF_CHECK_OK(ctx_->allocate_temp(t.dtype(), t.shape(), &aligned));
      pool_.emplace(name, std::make_pair(aligned, true));
    }
    functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
                                              aligned.flat<T>());
    return aligned;
  }

  // Tensors to be copied.
  std::vector<std::pair<Tensor, const Tensor>> copy_out_;
  // A pool of pre-allocated temporary tensors, with an indicator for whether
  // it's in use.
  std::map<string, std::pair<Tensor, bool>> pool_;
  // Op context
  OpKernelContext* ctx_ = nullptr;
  // Device
  const Device& device_;
};

}  // namespace

template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMOp : public OpKernel {
 public:
  explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* seq_len_max_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));

    const Tensor* x;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x));
    OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
    const int64 timelen = x->dim_size(0);
    const int64 batch_size = x->dim_size(1);
    const int64 input_size = x->dim_size(2);

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
    OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
                errors::InvalidArgument("cs_prev must be 2D"));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    if (batch_size * input_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "input_size are odd. You are using: batch_size="
                   << batch_size << ", input_size=" << input_size;
    }
    if (batch_size * cell_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "cell_size are odd. You are using: batch_size="
                   << batch_size << ", cell_size=" << cell_size;
    }

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
    OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
                errors::InvalidArgument("h_prev must be 2D"));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    OP_REQUIRES(ctx, w_tensor->dims() == 2,
                errors::InvalidArgument("w must be 2D"));
    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
    OP_REQUIRES(ctx, wci_tensor->dims() == 1,
                errors::InvalidArgument("wci must be 1D"));
    OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
    OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
                errors::InvalidArgument("wcf must be 1D"));
    OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
    OP_REQUIRES(ctx, wco_tensor->dims() == 1,
                errors::InvalidArgument("wco must be 1D"));
    OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    OP_REQUIRES(ctx, b_tensor->dims() == 1,
                errors::InvalidArgument("b must be 1D"));
    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    TensorShape batch_cell_shape({timelen, batch_size, cell_size});
    Tensor* i_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));

    Tensor* cs_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));

    Tensor* f_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));

    Tensor* o_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));

    Tensor* ci_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));

    Tensor* co_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));

    Tensor* h_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));

    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    SliceHelper<Device, T> slicer(ctx);
    for (int64 t = 0; t < seq_len_max; ++t) {
      const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");

      Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
      Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
      Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
      Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
      Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
      Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
      Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");

      functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                         cell_size)(
          ctx, device, forget_bias_, cell_clip_, use_peephole_,
          x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
          h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
          wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
          b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
          cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
          ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
          h_tensor.matrix<T>());
      slicer.FinishTimeStep();
    }

    if (seq_len_max < timelen) {
      Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
      Tensor h_tensor = h_out->Slice(seq_len_max, timelen);

      functor::TensorUnalignedZero<Device, T>()(
          device, cs_tensor.unaligned_flat<T>());
      functor::TensorUnalignedZero<Device, T>()(
          device, h_tensor.unaligned_flat<T>());
    }
  }

 private:
  float forget_bias_;
  float cell_clip_;
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                           \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      BlockLSTMOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
// REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                               \
  template <>                                                             \
  void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d,           \
                                            typename TTypes<T>::Flat t);  \
                                                                          \
  extern template struct TensorZero<GPUDevice, T>;                        \
                                                                          \
  template <>                                                             \
  void TensorUnalignedZero<GPUDevice, T>::operator()(                     \
      const GPUDevice& d, typename TTypes<T>::UnalignedFlat t);           \
                                                                          \
  extern template struct TensorUnalignedZero<GPUDevice, T>;

DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("BlockLSTM")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("seq_len_max")  \
                              .TypeConstraint<T>("T"),    \
                          BlockLSTMOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMGradOp : public OpKernel {
 public:
  explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* seq_len_max_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));

    const Tensor* x;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x));
    OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
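    // x is laid out as [timelen, batch_size, input_size].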
    const int64 timelen = x->dim_size(0);
    const int64 batch_size = x->dim_size(1);
    const int64 input_size = x->dim_size(2);

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    const int64 cell_size = w_tensor->dim_size(1) / 4;
    OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
                errors::InvalidArgument(
                    "w matrix rows don't match: ", input_size + cell_size,
                    " vs. ", w_tensor->dim_size(0)));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    OP_REQUIRES(
        ctx, cell_size == b_tensor->dim_size(0) / 4,
        errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
                                " vs. ", b_tensor->dim_size(0) / 4));

    const Tensor* i_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));

    const Tensor* cs_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));

    const Tensor* f_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));

    const Tensor* o_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));

    const Tensor* ci_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));

    const Tensor* co_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));

    const Tensor* h_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));

    const Tensor* cs_grad = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));

    const Tensor* h_grad = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));

    TensorShape batch_input_shape({timelen, batch_size, input_size});
    Tensor* x_grad;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("x_grad", batch_input_shape, &x_grad));

    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
                                        &cs_prev_grad_tensor));

    Tensor* h_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
                                        &h_prev_grad_tensor));

    Tensor* w_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
                                             &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
                                             &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
                                             &wco_grad_tensor));

    Tensor* b_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));

    TensorShape batch_cell_shape({batch_size, cell_size});

    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor xh_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           xh_tensor.shape(), &xh_grad_tensor));

    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &di_tensor));

    Tensor dicfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &dicfo_tensor));

    Tensor cs_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &cs_grad_tensor));

    Tensor h_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &h_grad_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());

    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    SliceHelper<Device, T> slicer(ctx);
    for (int64 t = seq_len_max - 1; t >= 0; --t) {
      const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
      const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
      const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
      const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
      const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
      const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
      const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");

      // Grab previous CS grad.
      const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
      const Tensor const_cs_grad_slice =
          slicer.InputSlice(*cs_grad, t, "cs_grad");
      functor::TensorAdd<Device, T>()(
          device, const_cs_prev_grad_tensor.flat<T>(),
          const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());

      // Combine previous h grad and h grad coming on top.
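      // h_prev_grad carries the recurrent gradient from step t + 1; h_grad
      // is the gradient fed into h at step t.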
      const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
      const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
      functor::TensorAdd<Device, T>()(
          device, const_h_prev_grad_tensor.flat<T>(),
          const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());

      const Tensor& const_cs_grad_tensor = cs_grad_tensor;
      const Tensor& const_h_grad_tensor = h_grad_tensor;

      Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
      functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                     cell_size)(
          ctx, device, use_peephole_, x_tensor.matrix<T>(),
          cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
          w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
          wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
          i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
          o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
          const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
          do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
          df_tensor.matrix<T>(), di_tensor.matrix<T>(),
          dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
          h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
          x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
          wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
          wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
      slicer.FinishTimeStep();
    }

    if (seq_len_max < timelen) {
      Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
      functor::TensorUnalignedZero<Device, T>()(
          device, x_grad_tensor.unaligned_flat<T>());
    }
  }

 private:
  bool use_peephole_;
};

#define REGISTER_KERNEL(T)                                              \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"),  \
      BlockLSTMGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
// REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d,                \
                                            typename TTypes<T>::ConstFlat src, \
                                            typename TTypes<T>::Flat dst);     \
                                                                               \
  template <>                                                                  \
  void TensorCopyUnaligned<GPUDevice, T>::operator()(                          \
      const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src,          \
      typename TTypes<T>::Flat dst);                                           \
                                                                               \
  template <>                                                                  \
  void TensorCopyToUnaligned<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T>::ConstFlat src,                   \
      typename TTypes<T>::UnalignedFlat dst);                                  \
                                                                               \
  template <>                                                                  \
  void TensorAdd<GPUDevice, T>::operator()(                                    \
      const GPUDevice& d, typename TTypes<T>::ConstFlat a,                     \
      typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c);            \
                                                                               \
  template <>                                                                  \
  void BlockLSTMBprop<GPUDevice, T, true>::operator()(                         \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole,             \
      typename TTypes<T>::ConstMatrix x,                                       \
      typename TTypes<T>::ConstMatrix cs_prev,                                 \
      typename TTypes<T>::ConstMatrix h_prev,                                  \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,   \
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,    \
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,  \
      typename TTypes<T>::ConstMatrix cs_grad,                                 \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
      typename TTypes<T>::Matrix dicfo,                                        \
      typename TTypes<T>::Matrix cs_prev_grad,                                 \
      typename TTypes<T>::Matrix h_prev_grad,                                  \
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,   \
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,     \
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,      \
      typename TTypes<T>::Vec b_grad);                                         \
                                                                               \
  extern template struct TensorCopy<GPUDevice, T>;                             \
  extern template struct TensorAdd<GPUDevice, T>;                              \
  extern template struct BlockLSTMBprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad")           \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("seq_len_max")  \
                              .TypeConstraint<T>("T"),    \
                          BlockLSTMGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow