/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T>
void LSTMBlockCellFpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d,
    const float forget_bias, const float cell_clip, bool use_peephole,
    typename TTypes<T>::ConstMatrix x, typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::Matrix xh, typename TTypes<T>::Matrix i,
    typename TTypes<T>::Matrix cs, typename TTypes<T>::Matrix f,
    typename TTypes<T>::Matrix o, typename TTypes<T>::Matrix ci,
    typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
    typename TTypes<T>::Matrix h) {
  // Concat xh = [x, h].
  xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x;
  xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev;

  // states1 = xh * w + b
  typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
  TensorBlasGemm<CPUDevice, T, false /* USE_CUBLAS */>::compute(
      ctx, d, false, false, typename gemm_compute_type<T>::type(1.f), const_xh,
      w, typename gemm_compute_type<T>::type(0.f), icfo);
  Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
  Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({cell.batch_size(), 1});
  icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);

  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});

  // Input gate.
  if (use_peephole) {
    auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
    i.device(d) =
        (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep)
            .sigmoid();
  } else {
    i.device(d) =
        icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid();
  }

  // Cell input.
  ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh();

  // Forget gate (w/ bias).
  if (use_peephole) {
    auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(T(forget_bias)) + f_peep)
                      .sigmoid();
  } else {
    f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) +
                   f.constant(T(forget_bias)))
                      .sigmoid();
  }

  // cs = ci .* i + f .* cs_prev
  cs.device(d) = i * ci + f * cs_prev;

  if (cell_clip > 0.0f) {
    cs.device(d) =
        cs.binaryExpr(cs.constant(T(cell_clip)), Eigen::scalar_clip_op<T>());
  }

  // co = tanh(cs)
  co.device(d) = cs.tanh();

  // Output gate.
  if (use_peephole) {
    auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    o.device(d) =
        (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep)
            .sigmoid();
  } else {
    o.device(d) =
        icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid();
  }

  // h = o .* co
  h.device(d) = o * co;
}

template <typename Device, typename T, bool USE_CUBLAS>
void LSTMBlockCellBpropWithEigen(
    const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d,
    bool use_peephole, typename TTypes<T>::ConstMatrix x,
    typename TTypes<T>::ConstMatrix cs_prev,
    typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
    typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
    typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
    typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
    typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
    typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
    typename TTypes<T>::ConstMatrix cs_grad,
    typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
    typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
    typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
    typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
    typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
    typename TTypes<T>::Vec wco_grad) {
  // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
  do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;

  // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
  dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;

  Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell.cell_size()});
  Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({cell.batch_size(), 1});
  if (use_peephole) {
    dcs.device(d) =
        dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
  }

  // dci[t] = tanh'(ci[t]) dcs[t] i[t]
  dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;

  // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
  df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;

  // di[t] = sigm'(i[t]) dcs[t] ci[t]
  di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;

  dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di;
  dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci;
  dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df;
  dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_;

  cs_prev_grad.device(d) = dcs * f;
  if (use_peephole) {
    cs_prev_grad.device(d) =
        cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
        df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
    wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
    wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
  }
}

#define DEFINE_CPU_SPECS(T) \
  template <> \
  void LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
      OpKernelContext* ctx, const CPUDevice& d, const float forget_bias, \
      const float cell_clip, bool use_peephole, \
      typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev, \
      typename TTypes<T>::ConstMatrix h_prev, \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h) { \
    LSTMBlockCellFpropWithEigen<T>( \
        *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \
        h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h); \
  } \
  template <> \
  void LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>::operator()( \
      OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \
      typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev, \
      typename TTypes<T>::ConstMatrix h_prev, \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
      typename TTypes<T>::ConstMatrix co, \
      typename TTypes<T>::ConstMatrix cs_grad, \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
      typename TTypes<T>::Matrix dicfo, \
      typename TTypes<T>::Matrix cs_prev_grad, \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
      typename TTypes<T>::Vec wco_grad) { \
    LSTMBlockCellBpropWithEigen<CPUDevice, T, false /* USE_CUBLAS */>( \
        *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \
        i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, \
        cs_prev_grad, wci_grad, wcf_grad, wco_grad); \
  } \
  template struct LSTMBlockCellFprop<CPUDevice, T, false /* USE_CUBLAS */>; \
  template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */>;

DEFINE_CPU_SPECS(float);
DEFINE_CPU_SPECS(Eigen::half);
#undef DEFINE_CPU_SPECS

}  // namespace functor

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellOp : public OpKernel {
 public:
  explicit LSTMBlockCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    // Allocate our output tensors.
    Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"h_prev"}, "i",
                            TensorShape({batch_size, cell_size}), &i_tensor));

    Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
                                  &cs_tensor));

    Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
                                  &f_tensor));

    Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"cs_prev"}, "o",
                            TensorShape({batch_size, cell_size}), &o_tensor));

    Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
                                  &ci_tensor));

    Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
                                  &co_tensor));

    Tensor* h_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
                                  &h_tensor));

    // Allocate our temp tensors.
    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, forget_bias_, cell_clip_, use_peephole_,
        x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
        h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
        wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
        xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
        f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
        co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
  }

 private:
  float forget_bias_;
  float cell_clip_;
  bool use_peephole_;
};

#define REGISTER_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LSTMBlockCellOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  void LSTMBlockCellFprop<GPUDevice, T, true>::operator()( \
      OpKernelContext* ctx, const GPUDevice& d, const float forget_bias, \
      const float cell_clip, bool use_peephole, \
      typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev, \
      typename TTypes<T>::ConstMatrix h_prev, \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs, \
      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o, \
      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co, \
      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h); \
  \
  extern template struct LSTMBlockCellFprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("LSTMBlockCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LSTMBlockCellOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(Eigen::half);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class LSTMBlockCellGradOp : public OpKernel {
 public:
  explicit LSTMBlockCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* x_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));

    const Tensor* i_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));

    const Tensor* cs_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));

    const Tensor* f_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));

    const Tensor* o_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));

    const Tensor* ci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));

    const Tensor* co_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));

    const Tensor* cs_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));

    const Tensor* h_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));

    const int64 batch_size = x_tensor->dim_size(0);
    const int64 input_size = x_tensor->dim_size(1);
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    // Sanity checks for our input shapes.
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
                                        cs_prev_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "f.dim_size(1) != cell_size: ", f_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0),
                    " vs. ", batch_size));
    OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1),
                    " vs. ", cell_size));

    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "cs_grad_tensor.dims(0) != batch_size: ",
                    cs_grad_tensor->dim_size(0), " vs. ", batch_size));
    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
                                        cs_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
                                        h_grad_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
                                        h_grad_tensor->dim_size(1), " vs. ",
                                        cell_size));

    // Allocate our output tensors.
    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"cs_grad"}, "cs_prev_grad",
                 TensorShape({batch_size, cell_size}), &cs_prev_grad_tensor));

    Tensor* dicfo_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            "dicfo", TensorShape({batch_size, cell_size * 4}),
                            &dicfo_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wci"}, "wci_grad", wci_tensor->shape(), &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wcf"}, "wcf_grad", wcf_tensor->shape(), &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output(
                 {"wco"}, "wco_grad", wco_tensor->shape(), &wco_grad_tensor));

    // Allocate our temp tensors.
    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           TensorShape({batch_size, cell_size}),
                                           &di_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());

    functor::LSTMBlockCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                       cell_size)(
        ctx, device, use_peephole_, x_tensor->matrix<T>(),
        cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
        w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
        wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
        cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
        ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
        cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
        do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
        df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
        cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
        wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
  }

 protected:
  bool use_peephole_;
};

#define REGISTER_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LSTMBlockCellGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  void LSTMBlockCellBprop<GPUDevice, T, true>::operator()( \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
      typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev, \
      typename TTypes<T>::ConstMatrix h_prev, \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i, \
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f, \
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci, \
      typename TTypes<T>::ConstMatrix co, \
      typename TTypes<T>::ConstMatrix cs_grad, \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
      typename TTypes<T>::Matrix dicfo, \
      typename TTypes<T>::Matrix cs_prev_grad, \
      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad, \
      typename TTypes<T>::Vec wco_grad); \
  \
  extern template struct LSTMBlockCellBprop<GPUDevice, T, \
                                            true /* USE_CUBLAS */>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("LSTMBlockCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LSTMBlockCellGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(Eigen::half);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

namespace {

// This helper class can be used to access timeslices of a 3D tensor. If a slice
// happens to be unaligned (usually because both batch size and number of cells
// are odd - this isn't common) this involves overhead, since data needs to be
// copied. However, if all slices are aligned, the bits aren't copied. In the
// cases where copying is needed, the outputs have to be recopied back.
// At the end of each time step you should call FinishTimeStep which does this,
// and also allows for reuse of temporary tensors.
template <typename Device, typename T>
class SliceHelper {
 public:
  explicit SliceHelper(OpKernelContext* ctx)
      : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}

  ~SliceHelper() {
    CHECK(copy_out_.empty());
    for (const auto& entry : pool_) {
      CHECK(!entry.second.second);  // nothing is in use
    }
  }

  // Slice through an input tensor. This may copy unaligned slices, but no
  // copying back will be done at the end.
  const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
    Tensor res = UnalignedSlice(t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      return AlignTensor(res, name);
    }
  }

  // Slice through an output tensor. This may copy unaligned slices, and
  // schedule copying back on destruction.
  Tensor OutputSlice(Tensor* t, int pos, const string& name) {
    Tensor res = UnalignedSlice(*t, pos);
    if (res.IsAligned()) {
      return res;
    } else {
      Tensor aligned = AlignTensor(res, name);
      copy_out_.emplace_back(res, aligned);
      return aligned;
    }
  }

  void FinishTimeStep() {
    for (const auto& p : copy_out_) {
      const Tensor& aligned = p.second;
      Tensor original = p.first;
      // Copy from aligned back to original.
      functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
                                                  original.unaligned_flat<T>());
    }
    copy_out_.clear();
    // Mark all entries as not in use.
    for (auto& entry : pool_) {
      entry.second.second = false;
    }
  }

 private:
  // Return a slice at position 'pos'. Result may be unaligned. The resulting
  // tensor always shares data with the source tensor.
  Tensor UnalignedSlice(const Tensor& t, int pos) const {
    Tensor res;
    // CHECK should never fail here, since the number of elements must match
    CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
    return res;
  }

  // Assumes input is not aligned, creates a temporary aligned tensor of the
  // same shape and copies the original tensor's content into it.
  Tensor AlignTensor(const Tensor& t, const string& name) {
    VLOG(1) << "AlignTensor called for " << name << ", shape "
            << t.shape().DebugString()
            << ". This is unnecessary copying. Consider using shapes with even "
            << "sizes";
    Tensor aligned;
    auto found = pool_.find(name);
    if (found != pool_.end()) {  // found in pool
      CHECK(!found->second.second) << "Tensor " << name << " is in use";
      found->second.second = true;  // mark in use
      aligned = found->second.first;
      CHECK(aligned.shape().IsSameSize(t.shape()));
      CHECK_EQ(aligned.dtype(), t.dtype());
    } else {  // allocate a new temporary tensor
      TF_CHECK_OK(ctx_->allocate_temp(t.dtype(), t.shape(), &aligned));
      pool_.emplace(name, std::make_pair(aligned, true));
    }
    functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
                                              aligned.flat<T>());
    return aligned;
  }

  // Tensors to be copied.
  std::vector<std::pair<Tensor, const Tensor>> copy_out_;
  // A pool of pre-allocated temporary tensors, with an indicator for whether
  // it's in use.
  std::map<string, std::pair<Tensor, bool>> pool_;
  // Op context
  OpKernelContext* ctx_ = nullptr;
  // Device
  const Device& device_;
};

}  // namespace

template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMOp : public OpKernel {
 public:
  explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* seq_len_max_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));

    const Tensor* x;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x));
    OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
    const int64 timelen = x->dim_size(0);
    const int64 batch_size = x->dim_size(1);
    const int64 input_size = x->dim_size(2);

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
    OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
                errors::InvalidArgument("cs_prev must be 2D"));
    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
                                        cs_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    const int64 cell_size = cs_prev_tensor->dim_size(1);

    if (batch_size * input_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "input_size are odd. You are using: batch_size="
                   << batch_size << ", input_size=" << input_size;
    }
    if (batch_size * cell_size % 2 == 1) {
      LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
                   << "cell_size are odd. You are using: batch_size="
                   << batch_size << ", cell_size=" << cell_size;
    }

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
    OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
                errors::InvalidArgument("h_prev must be 2D"));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
                                        h_prev_tensor->dim_size(0), " vs. ",
                                        batch_size));
    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
                errors::InvalidArgument(
                    "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1),
                    " vs. ", cell_size));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    OP_REQUIRES(ctx, w_tensor->dims() == 2,
                errors::InvalidArgument("w must be 2D"));
    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
                errors::InvalidArgument(
                    "w.dim_size(0) != input_size + cell_size: ",
                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
    OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4,
                errors::InvalidArgument(
                    "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1),
                    " vs. ", cell_size * 4));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
    OP_REQUIRES(ctx, wci_tensor->dims() == 1,
                errors::InvalidArgument("wci must be 1D"));
    OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
    OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
                errors::InvalidArgument("wcf must be 1D"));
    OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
    OP_REQUIRES(ctx, wco_tensor->dims() == 1,
                errors::InvalidArgument("wco must be 1D"));
    OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size,
                errors::InvalidArgument(
                    "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0),
                    " vs. ", cell_size));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    OP_REQUIRES(ctx, b_tensor->dims() == 1,
                errors::InvalidArgument("b must be 1D"));
    OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4,
                errors::InvalidArgument(
                    "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0),
                    " vs. ", cell_size * 4));

    TensorShape batch_cell_shape({timelen, batch_size, cell_size});
    Tensor* i_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));

    Tensor* cs_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));

    Tensor* f_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));

    Tensor* o_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));

    Tensor* ci_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));

    Tensor* co_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));

    Tensor* h_out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));

    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor icfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &icfo_tensor));

    const Device& device = ctx->eigen_device<Device>();

    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    SliceHelper<Device, T> slicer(ctx);
    for (int64 t = 0; t < seq_len_max; ++t) {
      const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");

      Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
      Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
      Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
      Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
      Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
      Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
      Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");

      functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                         cell_size)(
          ctx, device, forget_bias_, cell_clip_, use_peephole_,
          x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
          h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
          wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
          b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
          cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
          ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
          h_tensor.matrix<T>());
      slicer.FinishTimeStep();
    }

    if (seq_len_max < timelen) {
      Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
      Tensor h_tensor = h_out->Slice(seq_len_max, timelen);

      functor::TensorUnalignedZero<Device, T>()(device,
                                                cs_tensor.unaligned_flat<T>());
      functor::TensorUnalignedZero<Device, T>()(device,
                                                h_tensor.unaligned_flat<T>());
    }
  }

 private:
  float forget_bias_;
  float cell_clip_;
  bool use_peephole_;
};

#define REGISTER_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      BlockLSTMOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d, \
                                            typename TTypes<T>::Flat t); \
  \
  extern template struct TensorZero<GPUDevice, T>; \
  \
  template <> \
  void TensorUnalignedZero<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T>::UnalignedFlat t); \
  \
  extern template struct TensorUnalignedZero<GPUDevice, T>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T) \
  REGISTER_KERNEL_BUILDER(Name("BlockLSTM") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("seq_len_max") \
                              .TypeConstraint<T>("T"), \
                          BlockLSTMOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(Eigen::half);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMGradOp : public OpKernel {
 public:
  explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* seq_len_max_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));

    const Tensor* x;
    OP_REQUIRES_OK(ctx, ctx->input("x", &x));
    OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
    const int64 timelen = x->dim_size(0);
    const int64 batch_size = x->dim_size(1);
    const int64 input_size = x->dim_size(2);

    const Tensor* cs_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));

    const Tensor* h_prev_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));

    const Tensor* w_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
    const int64 cell_size = w_tensor->dim_size(1) / 4;
    OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
                errors::InvalidArgument(
                    "w matrix rows don't match: ", input_size + cell_size,
                    " vs. ", w_tensor->dim_size(0)));

    const Tensor* wci_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));

    const Tensor* wcf_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));

    const Tensor* wco_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));

    const Tensor* b_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
    OP_REQUIRES(
        ctx, cell_size == b_tensor->dim_size(0) / 4,
        errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
                                " vs. ", b_tensor->dim_size(0)));

    const Tensor* i_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));

    const Tensor* cs_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));

    const Tensor* f_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));

    const Tensor* o_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));

    const Tensor* ci_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));

    const Tensor* co_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));

    const Tensor* h_out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));

    const Tensor* cs_grad = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));

    const Tensor* h_grad = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));

    TensorShape batch_input_shape({timelen, batch_size, input_size});
    Tensor* x_grad;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("x_grad", batch_input_shape, &x_grad));

    Tensor* cs_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
                                        &cs_prev_grad_tensor));

    Tensor* h_prev_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
                                        &h_prev_grad_tensor));

    Tensor* w_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));

    Tensor* wci_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
                                             &wci_grad_tensor));

    Tensor* wcf_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
                                             &wcf_grad_tensor));

    Tensor* wco_grad_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
                                             &wco_grad_tensor));

    Tensor* b_grad_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));

    TensorShape batch_cell_shape({batch_size, cell_size});

    Tensor xh_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::v(),
                            TensorShape({batch_size, input_size + cell_size}),
                            &xh_tensor));

    Tensor xh_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           xh_tensor.shape(), &xh_grad_tensor));

    Tensor do_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &do_tensor));

    Tensor dcs_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dcs_tensor));

    Tensor dci_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &dci_tensor));

    Tensor df_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &df_tensor));

    Tensor di_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &di_tensor));

    Tensor dicfo_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                      TensorShape({batch_size, cell_size * 4}),
                                      &dicfo_tensor));

    Tensor cs_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &cs_grad_tensor));

    Tensor h_grad_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
                                           batch_cell_shape, &h_grad_tensor));

    const Device& device = ctx->eigen_device<Device>();

    functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
    functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
    functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());

    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
    SliceHelper<Device, T> slicer(ctx);
    for (int64 t = seq_len_max - 1; t >= 0; --t) {
      const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
      const Tensor& cs_prev_tensor2 =
          t == 0 ? *cs_prev_tensor
                 : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
      const Tensor& h_prev_tensor2 =
          t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
      const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
      const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
      const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
      const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
      const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
      const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");

      // Grab previous CS grad.
      const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
      const Tensor const_cs_grad_slice =
          slicer.InputSlice(*cs_grad, t, "cs_grad");
      functor::TensorAdd<Device, T>()(
          device, const_cs_prev_grad_tensor.flat<T>(),
          const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());

      // Combine previous h grad and h grad coming on top.
      const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
      const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
      functor::TensorAdd<Device, T>()(
          device, const_h_prev_grad_tensor.flat<T>(),
          const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());

      const Tensor& const_cs_grad_tensor = cs_grad_tensor;
      const Tensor& const_h_grad_tensor = h_grad_tensor;

      Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
      functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
                                                     cell_size)(
          ctx, device, use_peephole_, x_tensor.matrix<T>(),
          cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
          w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
          wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
          i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
          o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
          const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
          do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
          df_tensor.matrix<T>(), di_tensor.matrix<T>(),
          dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
          h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
          x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
          wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
          wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
      slicer.FinishTimeStep();
    }

    if (seq_len_max < timelen) {
      Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
      functor::TensorUnalignedZero<Device, T>()(
          device, x_grad_tensor.unaligned_flat<T>());
    }
  }

 private:
  bool use_peephole_;
};

#define REGISTER_KERNEL(T) \
  REGISTER_KERNEL_BUILDER( \
      Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      BlockLSTMGradOp<CPUDevice, T, false>);
REGISTER_KERNEL(float);
REGISTER_KERNEL(Eigen::half);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d, \
                                            typename TTypes<T>::ConstFlat src, \
                                            typename TTypes<T>::Flat dst); \
  \
  template <> \
  void TensorCopyUnaligned<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src, \
      typename TTypes<T>::Flat dst); \
  \
  template <> \
  void TensorCopyToUnaligned<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T>::ConstFlat src, \
      typename TTypes<T>::UnalignedFlat dst); \
  \
  template <> \
  void TensorAdd<GPUDevice, T>::operator()( \
      const GPUDevice& d, typename TTypes<T>::ConstFlat a, \
      typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c); \
  \
  template <> \
  void BlockLSTMBprop<GPUDevice, T, true>::operator()( \
      OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \
      typename TTypes<T>::ConstMatrix x, \
      typename TTypes<T>::ConstMatrix cs_prev, \
      typename TTypes<T>::ConstMatrix h_prev, \
      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco, \
      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh, \
      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs, \
      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o, \
      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co, \
      typename TTypes<T>::ConstMatrix cs_grad, \
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci, \
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di, \
      typename TTypes<T>::Matrix dicfo, \
      typename TTypes<T>::Matrix cs_prev_grad, \
      typename TTypes<T>::Matrix h_prev_grad, \
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad, \
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad, \
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad, \
      typename TTypes<T>::Vec b_grad); \
  \
  extern template struct TensorCopy<GPUDevice, T>; \
  extern template struct TensorAdd<GPUDevice, T>; \
  extern template struct BlockLSTMBprop<GPUDevice, T, true>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
// DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T) \
  REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("seq_len_max") \
                              .TypeConstraint<T>("T"), \
                          BlockLSTMGradOp<GPUDevice, T, true>);

REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(Eigen::half);
// REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow