# -*- coding: utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Cudnn RNN models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import os
import unittest

import numpy as np

from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib

CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION

CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU
CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU
CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH

CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER
CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER
CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER
CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER


def _CreateModel(rnn_mode,
                 num_layers,
                 num_units,
                 input_size,
                 input_mode="linear_input",
                 direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
                 dtype=dtypes.float32,
                 dropout=0.):
  del input_mode
  if rnn_mode == cudnn_rnn_ops.CUDNN_LSTM:
    model_fn = cudnn_rnn_ops.CudnnLSTM
  elif rnn_mode == cudnn_rnn_ops.CUDNN_GRU:
    model_fn = cudnn_rnn_ops.CudnnGRU
  elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_TANH:
    model_fn = cudnn_rnn_ops.CudnnRNNTanh
  elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_RELU:
    model_fn = cudnn_rnn_ops.CudnnRNNRelu
  else:
    raise ValueError("Invalid rnn_mode: %s" % rnn_mode)
  return model_fn(
      num_layers,
      num_units,
      input_size,
      direction=direction,
      dtype=dtype,
      dropout=dropout)


def _CreateParamsSavable(params,
                         model,
                         base_variable_scope=None,
                         name="params_canonical"):
  """Creates an RNNParamsSaveable for the weight and bias parameters.

  Args:
    params: a Variable for weight and bias parameters.
    model: a CudnnRNN model.
    base_variable_scope: a string, prefix of names of saved variables.
    name: a string, name of the RNNParamsSaveable object.

  Returns:
    an RNNParamsSaveable object.
  """
  if model._rnn_mode == CUDNN_LSTM:
    fn = cudnn_rnn_ops.CudnnLSTMSaveable
  elif model._rnn_mode == CUDNN_GRU:
    fn = cudnn_rnn_ops.CudnnGRUSaveable
  elif model._rnn_mode == CUDNN_RNN_TANH:
    fn = cudnn_rnn_ops.CudnnRNNTanhSaveable
  elif model._rnn_mode == CUDNN_RNN_RELU:
    fn = cudnn_rnn_ops.CudnnRNNReluSaveable
  else:
    raise ValueError("Invalid rnn_mode: %s" % model._rnn_mode)
  params_saveable = fn(
      params,
      model.num_layers,
      model.num_units,
      model.input_size,
      model.input_mode,
      model.direction,
      scope=base_variable_scope,
      name=name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
  return params_saveable
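

# A minimal usage sketch (added here; it mirrors the tests below rather than
# documenting a public API): cudnn keeps all weights and biases in one flat,
# opaque buffer, and the saveable converts that buffer to canonical per-layer
# tensors so a tf.train.Saver can checkpoint them.
#
#   model = _CreateModel(CUDNN_LSTM, num_layers=2, num_units=7, input_size=3)
#   params = variables.Variable(
#       random_ops.random_uniform([model.params_size()]),
#       validate_shape=False)
#   saveable = _CreateParamsSavable(params, model)
#   weights, biases = saveable._OpaqueParamsToCanonical()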


def _MinLSTMParamSize(num_layers,
                      num_units,
                      input_size,
                      direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION):
  if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION:
    first_layer_weights = 4 * num_units * (num_units + input_size)
    higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units
    all_biases = 8 * num_layers * num_units
    return first_layer_weights + higher_layer_weights + all_biases
  elif direction == cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION:
    first_layer_weights = 4 * num_units * (num_units + input_size)
    higher_layer_weights = (num_layers - 1) * (
        4 * 2 * num_units * num_units + 4 * num_units**2)
    all_biases = 8 * num_layers * num_units
    return 2 * (first_layer_weights + higher_layer_weights + all_biases)
  else:
    raise ValueError("%s direction is not supported." % direction)
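

# Illustrative sanity check of the arithmetic above (an addition, not part of
# the original test): a 2-layer unidirectional LSTM with num_units=7 and
# input_size=3 (the shape used in the save/restore tests below) needs at least
#   4 * 7 * (7 + 3)     = 280  first-layer weights,
#   8 * (2 - 1) * 7 * 7 = 392  higher-layer weights,
#   8 * 2 * 7           = 112  biases,
# i.e. 784 parameters in total.
assert _MinLSTMParamSize(2, 7, 3) == 280 + 392 + 112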


class CudnnRNNTestSaveRestore(TensorFlowTestCase):

  def _CompareWeights(self, lhs, rhs):
    self.assertEqual(len(lhs), len(rhs))
    for lw, rw in zip(lhs, rhs):
      self.assertAllEqual(lw, rw)

  def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction):
    self.assertEqual(len(lhs), len(rhs))
    if rnn_mode == CUDNN_LSTM:
      num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER
    elif rnn_mode == CUDNN_GRU:
      num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER
    elif rnn_mode == CUDNN_RNN_TANH:
      num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER
    else:
      num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER
    num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2
    num_params_per_layer *= num_dirs
    self.assertEqual(num_params_per_layer * num_layers, len(lhs))

    for i in range(num_layers):
      layer_lhs = lhs[i * num_params_per_layer:(i + 1) * num_params_per_layer]
      layer_rhs = rhs[i * num_params_per_layer:(i + 1) * num_params_per_layer]
      if direction == CUDNN_RNN_UNIDIRECTION:
        self._CompareSingleLayerBiases(layer_lhs, layer_rhs)
      else:
        size = len(layer_lhs)
        fw_lhs, bw_lhs = layer_lhs[:size // 2], layer_lhs[size // 2:]
        fw_rhs, bw_rhs = layer_rhs[:size // 2], layer_rhs[size // 2:]
        self._CompareSingleLayerBiases(fw_lhs, fw_rhs)
        self._CompareSingleLayerBiases(bw_lhs, bw_rhs)

  def _CompareSingleLayerBiases(self, lhs, rhs):
    # cudnn keeps two bias vectors per gate (one applied to the input, one to
    # the recurrent state); only their elementwise sum affects the cell
    # output, so the two halves are compared via their sums.
    self.assertEqual(len(lhs), len(rhs))

    lf_lhs, rt_lhs = lhs[:len(lhs) // 2], lhs[len(lhs) // 2:]
    lf_rhs, rt_rhs = rhs[:len(rhs) // 2], rhs[len(rhs) // 2:]
    self.assertEqual(len(lf_lhs), len(rt_lhs))
    self.assertEqual(len(lf_rhs), len(rt_rhs))

    sum_lhs, sum_rhs = [], []
    for lf, rt in zip(lf_lhs, rt_lhs):
      sum_lhs.append(lf + rt)
    for lf, rt in zip(lf_rhs, rt_rhs):
      sum_rhs.append(lf + rt)
    self.assertEqual(len(sum_lhs), len(sum_rhs))
    for lf, rt in zip(sum_lhs, sum_rhs):
      self.assertAllEqual(lf, rt)

  def _testSaveRestoreVariable(self, rnn_mode, direction, dtype):
    num_layers = 2
    num_units = 7
    input_size = 3
    with ops.Graph().as_default():
      model = _CreateModel(
          rnn_mode,
          num_layers=num_layers,
          num_units=num_units,
          input_size=input_size,
          direction=direction,
          dtype=dtype)
      random_seed.set_random_seed(1234)
      params_size_t = model.params_size()
      params = variables.Variable(
          random_ops.random_uniform([params_size_t], dtype=dtype),
          dtype=dtype,
          validate_shape=False)
      saveable = _CreateParamsSavable(params, model)
      weights, biases = saveable._OpaqueParamsToCanonical()
      reset_params = state_ops.assign(
          params,
          array_ops.zeros([params_size_t], dtype=dtype),
          validate_shape=False)
      save_path = os.path.join(self.get_temp_dir(),
                               "save-restore-variable-test")
      saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
      # Passing graph explicitly, otherwise an old sess would be reused.
      with self.test_session(
          use_gpu=True, graph=ops.get_default_graph()) as sess:
        sess.run(variables.global_variables_initializer())
        val = saver.save(sess, save_path)
        self.assertEqual(save_path, val)

        weights_v, biases_v = sess.run([weights, biases])

        sess.run(reset_params)
        saver.restore(sess, save_path)
        weights_v_restored, biases_v_restored = sess.run([weights, biases])

        self._CompareWeights(weights_v, weights_v_restored)
        self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers,
                            direction)
  def _testSaveRestoreTwoVariables(self, rnn_mode, direction, dtype):
    num_layers = 2
    num_units = 7
    input_size = 3
    with ops.Graph().as_default():
      model = _CreateModel(
          rnn_mode,
          num_layers=num_layers,
          num_units=num_units,
          input_size=input_size,
          direction=direction,
          dtype=dtype)
      random_seed.set_random_seed(1234)
      params_size_t = model.params_size()
      names = ["rnn_1", "rnn_2"]
      param_vars = [
          variables.Variable(
              random_ops.random_uniform([params_size_t], dtype=dtype),
              dtype=dtype,
              validate_shape=False) for _ in names
      ]
      saveables = []
      for name, params in zip(names, param_vars):
        saveables.append(_CreateParamsSavable(params, model, name, name))
      weights1, biases1 = saveables[0]._OpaqueParamsToCanonical()
      weights2, biases2 = saveables[1]._OpaqueParamsToCanonical()
      reset_params = [
          state_ops.assign(
              params,
              array_ops.zeros([params_size_t], dtype=dtype),
              validate_shape=False) for params in param_vars
      ]
      save_path = os.path.join(self.get_temp_dir(),
                               "save-restore-variable-test")
      saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
      # Passing graph explicitly, otherwise an old sess would be reused.
      with self.test_session(
          use_gpu=True, graph=ops.get_default_graph()) as sess:
        sess.run(variables.global_variables_initializer())
        val = saver.save(sess, save_path)
        self.assertEqual(save_path, val)
        weights1_v, biases1_v = sess.run([weights1, biases1])
        weights2_v, biases2_v = sess.run([weights2, biases2])

        sess.run(reset_params)
        saver.restore(sess, save_path)
        weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1])
        weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2])

        self._CompareWeights(weights1_v, weights1_v_restored)
        self._CompareWeights(weights2_v, weights2_v_restored)
        self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode,
                            num_layers, direction)
        self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode,
                            num_layers, direction)

  def _testSaveRestoreOutput(self, rnn_mode, direction, dtype):
    with ops.Graph().as_default():
      num_layers = 2
      num_units = 7
      input_size = 7
      seq_length = 10
      batch_size = 5
      dir_count = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2
      model = _CreateModel(
          rnn_mode,
          num_layers,
          num_units,
          input_size,
          direction=direction,
          dtype=dtype)
      params_size_t = model.params_size()
      params = variables.Variable(
          array_ops.ones([params_size_t], dtype=dtype),
          validate_shape=False,
          dtype=dtype)
      _CreateParamsSavable(params, model)
      save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test")
      saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

      np.random.seed(1234)
      has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM)
      input_data = constant_op.constant(
          np.random.randn(seq_length, batch_size, input_size), dtype=dtype)
      input_h = constant_op.constant(
          np.random.randn(num_layers * dir_count, batch_size, num_units),
          dtype=dtype)
      if has_input_c:
        input_c = constant_op.constant(
            np.random.randn(num_layers * dir_count, batch_size, num_units),
            dtype=dtype)
        outputs = model(
            input_data=input_data,
            input_h=input_h,
            input_c=input_c,
            params=params,
            is_training=False)
      else:
        outputs = model(
            input_data=input_data,
            input_h=input_h,
            params=params,
            is_training=False)
      total_sum = sum(map(math_ops.reduce_sum, outputs))
      # Passing graph explicitly, otherwise an old sess would be reused.
      with self.test_session(
          use_gpu=True, graph=ops.get_default_graph()) as sess:
        sess.run(variables.global_variables_initializer())
        total_sum_v = sess.run(total_sum)
        val = saver.save(sess, save_path)
        self.assertEqual(save_path, val)
      # Passing graph explicitly, otherwise an old sess would be reused.
      with self.test_session(
          use_gpu=True, graph=ops.get_default_graph()) as sess:
        reset_params = state_ops.assign(
            params,
            array_ops.zeros([params_size_t], dtype=dtype),
            validate_shape=False)
        sess.run(reset_params)
        saver.restore(sess, save_path)
        total_sum_v_restored = sess.run(total_sum)
        self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5)

  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testSaveRestore(self):
    rnn_modes = [
        cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU,
        cudnn_rnn_ops.CUDNN_RNN_TANH, cudnn_rnn_ops.CUDNN_RNN_RELU
    ]
    directions = [
        cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
        cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
    ]
    dtype_list = [dtypes.float32, dtypes.float64]
    for rnn_mode, direction, dtype in itertools.product(rnn_modes, directions,
                                                        dtype_list):
      self._testSaveRestoreVariable(rnn_mode, direction, dtype)
      self._testSaveRestoreTwoVariables(rnn_mode, direction, dtype)
      self._testSaveRestoreOutput(rnn_mode, direction, dtype)


class CudnnRNNTestParamsSize(TensorFlowTestCase):

  def _testOneLSTMParamsSize(self, num_layers, num_units, input_size,
                             direction):
    logging.info("Testing one lstm param size with config: %s", locals())
    min_params_size = _MinLSTMParamSize(num_layers, num_units, input_size,
                                        direction)
    model = _CreateModel(
        cudnn_rnn_ops.CUDNN_LSTM,
        num_layers,
        num_units,
        input_size,
        direction=direction)
    params_size = model.params_size()
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      params_size_v = sess.run(params_size)
      self.assertLessEqual(min_params_size, params_size_v)

  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testLSTMParamsSize(self):
    # Each config is [num_layers, num_units, input_size].
    test_configs = [
        [4, 200, 200],
        [4, 200, 300],
        [4, 200, 100],
        [1, 100, 200],
        [2, 200, 100],
        [3, 200, 400],
    ]
    directions = [
        cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
        cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
    ]
    for (config, direction) in itertools.product(test_configs, directions):
      num_layers, num_units, input_size = config
      with ops.Graph().as_default():
        self._testOneLSTMParamsSize(num_layers, num_units, input_size,
                                    direction)
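

# Note (added for clarity): _MinLSTMParamSize is treated as a lower bound
# rather than an exact count, since the opaque parameter buffer cudnn
# allocates may legitimately be larger (e.g. padding or layout overhead);
# hence the assertLessEqual above rather than assertEqual.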


class CudnnRNNTestInference(TensorFlowTestCase):

  def _testOneSimpleInference(self, rnn_mode, num_layers, num_units,
                              input_size, batch_size, seq_length, dir_count,
                              dropout, expected, tolerance):
    random_seed.set_random_seed(5678)
    model = _CreateModel(
        rnn_mode,
        num_layers,
        num_units,
        input_size,
        input_mode="auto_select",
        direction=(cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
                   if dir_count == 1 else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION),
        dropout=dropout)
    has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM)
    params_size_t = model.params_size()
    input_data = array_ops.ones([seq_length, batch_size, input_size])
    input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units])
    params = variables.Variable(
        array_ops.ones([params_size_t]), validate_shape=False)
    if has_input_c:
      input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units])
      output, output_h, output_c = model(
          input_data=input_data,
          input_h=input_h,
          input_c=input_c,
          params=params,
          is_training=False)
    else:
      output, output_h = model(
          input_data=input_data,
          input_h=input_h,
          params=params,
          is_training=False)
    output_sum = math_ops.reduce_sum(output)
    output_h_sum = math_ops.reduce_sum(output_h)
    total_sum = output_sum + output_h_sum
    if has_input_c:
      output_c_sum = math_ops.reduce_sum(output_c)
      total_sum += output_c_sum
    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      sess.run(variables.global_variables_initializer())
      total_sum_v = sess.run([total_sum])

      self.assertAllClose(
          total_sum_v[0], expected, atol=tolerance, rtol=tolerance)

  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testSimpleInference(self):
    test_configs = [
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM,
            "expected": 231833.22,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 4,
                "num_units": 200,
                "input_size": 200,
                "batch_size": 20,
                "seq_length": 10,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_GRU,
            "expected": 56000,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 4,
                "num_units": 200,
                "input_size": 200,
                "batch_size": 20,
                "seq_length": 10,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH,
            "expected": 56000,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 4,
                "num_units": 200,
                "input_size": 200,
                "batch_size": 20,
                "seq_length": 10,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU,
            "expected": 130688,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 2,
                "num_units": 8,
                "input_size": 4,
                "batch_size": 4,
                "seq_length": 2,
                "dir_count": 1,
            },
        },
    ]
    # Cudnn scales the result for dropout during training, so dropout has no
    # impact on inference results. (lstm, gru and rnn_tanh are saturated in
    # this test; the rnn_relu case is the most demonstrative of the
    # dropout-invariant nature of CudnnRnn. See the note after this class.)
    dropouts = [0., 0.5, 1.]
    for (config, dropout) in itertools.product(test_configs, dropouts):
      rnn_mode = config["rnn_mode"]
      expected = config["expected"]
      tolerance = config["tolerance"]
      shape = config["shape"]
      with ops.Graph().as_default():
        self._testOneSimpleInference(
            rnn_mode, shape["num_layers"], shape["num_units"],
            shape["input_size"], shape["batch_size"], shape["seq_length"],
            shape["dir_count"], dropout, expected, tolerance)
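

# Note (added, illustrative): the dropout invariance above is what one gets
# from the "inverted dropout" scheme, where kept activations are scaled by
# 1 / (1 - p) during training so that their expectation is unchanged, while
# inference applies no mask at all. In plain numpy (a sketch, not the cudnn
# kernel itself), for a rate 0 <= p < 1:
#
#   keep = (np.random.rand(*x.shape) >= p).astype(x.dtype)
#   train_out = x * keep / (1. - p)  # E[train_out] == x
#   infer_out = x                    # dropout is a no-op at inference
#
# which is why, at inference, every dropout rate tested above must produce
# the same expected sums.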


class CudnnRNNTestTraining(TensorFlowTestCase):

  def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size,
                             batch_size, seq_length, dir_count, dropout, dtype,
                             delta, tolerance):
    # Gradient checking runs two forward ops with almost the same input. Need
    # to make sure the drop patterns across the two runs are the same.
    logging.info("Training test with config: %s", locals())
    old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False))
    os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
    has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM)
    random_seed.set_random_seed(5678)
    direction = (cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
                 if dir_count == 1 else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
    model = _CreateModel(
        rnn_mode,
        num_layers,
        num_units,
        input_size,
        direction=direction,
        dtype=dtype,
        dropout=dropout)
    params_size_t = model.params_size()
    input_data = variables.Variable(
        random_ops.random_uniform(
            [seq_length, batch_size, input_size], dtype=dtype),
        dtype=dtype)
    input_h = variables.Variable(
        random_ops.random_uniform(
            [num_layers * dir_count, batch_size, num_units], dtype=dtype),
        dtype=dtype)
    params = variables.Variable(
        random_ops.random_uniform([params_size_t], dtype=dtype),
        validate_shape=False,
        dtype=dtype)
    if has_input_c:
      input_c = variables.Variable(
          random_ops.random_uniform(
              [num_layers * dir_count, batch_size, num_units], dtype=dtype),
          dtype=dtype)

      output, output_h, output_c = model(
          input_data=input_data,
          input_h=input_h,
          input_c=input_c,
          params=params)
    else:
      output, output_h = model(
          input_data=input_data, input_h=input_h, params=params)
    output_sum = math_ops.reduce_sum(output)
    output_h_sum = math_ops.reduce_sum(output_h)
    total_sum = output_sum + output_h_sum
    if has_input_c:
      output_c_sum = math_ops.reduce_sum(output_c)
      total_sum += output_c_sum

    with self.test_session(
        use_gpu=True, graph=ops.get_default_graph()) as sess:
      params_size_v = sess.run(params_size_t)
      inputs_and_shapes = [
          (input_data, [seq_length, batch_size, input_size]),
          (input_h, [num_layers * dir_count, batch_size, num_units]),
          (params, [params_size_v]),
      ]
      if has_input_c:
        inputs_and_shapes.append(
            (input_c, [num_layers * dir_count, batch_size, num_units]))
      sess.run(variables.global_variables_initializer())
      all_inputs = [entry[0] for entry in inputs_and_shapes]
      all_shapes = [entry[1] for entry in inputs_and_shapes]

      err = gradient_checker.compute_gradient_error(
          all_inputs, all_shapes, total_sum, [1], delta=delta)

      self.assertLess(err, tolerance)
    os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state
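
  # Note (added, illustrative): compute_gradient_error builds the Jacobian of
  # total_sum w.r.t. each input twice, once from the graph's symbolic
  # gradients and once numerically from central differences,
  #   df/dx ~= (f(x + delta) - f(x - delta)) / (2 * delta),
  # and returns the largest elementwise discrepancy between the two; the test
  # passes when that discrepancy stays below the per-config tolerance.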
  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testSimpleTraining(self):
    test_configs = [
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_GRU,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM,
            "dtype": dtypes.float32,
            "tolerance": 1.5e-2,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_GRU,
            "dtype": dtypes.float32,
            "tolerance": 4e-3,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH,
            "dtype": dtypes.float32,
            "tolerance": 5e-3,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU,
            "dtype": dtypes.float32,
            "tolerance": 5e-1,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
    ]
    dropouts = [0., 0.5, 1.]
    dir_counts = [1]
    for config, dropout, dir_count in itertools.product(test_configs, dropouts,
                                                        dir_counts):
      rnn_mode = config["rnn_mode"]
      dtype = config.get("dtype", dtypes.float32)
      delta = config.get("delta", 1e-3)
      tolerance = config["tolerance"]
      shape = config["shape"]
      with ops.Graph().as_default():
        self._testOneSimpleTraining(rnn_mode, shape["num_layers"],
                                    shape["num_units"], shape["input_size"],
                                    shape["batch_size"], shape["seq_length"],
                                    dir_count, dropout, dtype, delta, tolerance)


if __name__ == "__main__":
  googletest.main()