# -*- coding: utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     16 """Tests for Cudnn RNN models."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 import os
     23 import unittest
     24 
     25 import numpy as np
     26 
     27 from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
     28 from tensorflow.core.protobuf import saver_pb2
     29 from tensorflow.python.framework import constant_op
     30 from tensorflow.python.framework import dtypes
     31 from tensorflow.python.framework import ops
     32 from tensorflow.python.framework import random_seed
     33 from tensorflow.python.framework.test_util import TensorFlowTestCase
     34 from tensorflow.python.ops import array_ops
     35 from tensorflow.python.ops import gradient_checker
     36 from tensorflow.python.ops import math_ops
     37 from tensorflow.python.ops import random_ops
     38 from tensorflow.python.ops import state_ops
     39 from tensorflow.python.ops import variables
     40 from tensorflow.python.platform import googletest
     41 from tensorflow.python.platform import test
     42 from tensorflow.python.platform import tf_logging as logging
     43 from tensorflow.python.training import saver as saver_lib
     44 
     45 CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
     46 CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
     47 
     48 CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
     49 CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU
     50 CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU
     51 CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH
     52 
     53 CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER
     54 CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER
     55 CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER
     56 CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER
     57 
     58 
def _CreateModel(rnn_mode,
                 num_layers,
                 num_units,
                 input_size,
                 input_mode="linear_input",
                 direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
                 dtype=dtypes.float32,
                 dropout=0.):
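  """Builds a Cudnn RNN model of the given rnn_mode.

  Note: `input_mode` is accepted but not forwarded to the model
  constructor (hence the `del input_mode` below).
  """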
  del input_mode
  if rnn_mode == cudnn_rnn_ops.CUDNN_LSTM:
    model_fn = cudnn_rnn_ops.CudnnLSTM
  elif rnn_mode == cudnn_rnn_ops.CUDNN_GRU:
    model_fn = cudnn_rnn_ops.CudnnGRU
  elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_TANH:
    model_fn = cudnn_rnn_ops.CudnnRNNTanh
  elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_RELU:
    model_fn = cudnn_rnn_ops.CudnnRNNRelu
  else:
    raise ValueError("Invalid rnn_mode: %s" % rnn_mode)
  return model_fn(
      num_layers,
      num_units,
      input_size,
      direction=direction,
      dtype=dtype,
      dropout=dropout)


def _CreateParamsSavable(params,
                         model,
                         base_variable_scope=None,
                         name="params_canonical"):
  """Creates an RNNParamsSaveable for the weight and bias parameters.

  Args:
    params: a Variable holding the opaque weight and bias parameters.
    model: a CudnnRNN model.
    base_variable_scope: a string, prefix of names of saved variables.
    name: a string, name of the RNNParamsSaveable object.
  Returns:
    an RNNParamsSaveable object.
  """
  if model._rnn_mode == CUDNN_LSTM:
    fn = cudnn_rnn_ops.CudnnLSTMSaveable
  elif model._rnn_mode == CUDNN_GRU:
    fn = cudnn_rnn_ops.CudnnGRUSaveable
  elif model._rnn_mode == CUDNN_RNN_TANH:
    fn = cudnn_rnn_ops.CudnnRNNTanhSaveable
  elif model._rnn_mode == CUDNN_RNN_RELU:
    fn = cudnn_rnn_ops.CudnnRNNReluSaveable
  else:
    raise ValueError("Invalid rnn_mode: %s" % model._rnn_mode)
  params_saveable = fn(
      params,
      model.num_layers,
      model.num_units,
      model.input_size,
      model.input_mode,
      model.direction,
      scope=base_variable_scope,
      name=name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
  return params_saveable


def _MinLSTMParamSize(num_layers,
                      num_units,
                      input_size,
                      direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION):
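  """Returns a lower bound on the size of the opaque cuDNN LSTM param blob.

  Per layer and direction, an LSTM has 4 gates. Each gate has an input
  weight matrix (num_units x input_width), a recurrent weight matrix
  (num_units x num_units), and, in cuDNN's layout, two bias vectors of
  length num_units; the two biases per gate account for the
  8 * num_layers * num_units bias term. For layers above the first, the
  input width is num_units (2 * num_units for bidirectional models, since
  forward and backward outputs are concatenated).
  """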
  if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION:
    first_layer_weights = 4 * num_units * (num_units + input_size)
    higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units
    all_biases = 8 * num_layers * num_units
    return first_layer_weights + higher_layer_weights + all_biases
  elif direction == cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION:
    first_layer_weights = 4 * num_units * (num_units + input_size)
    higher_layer_weights = (num_layers - 1) * (
        4 * 2 * num_units * num_units + 4 * num_units**2)
    all_biases = 8 * num_layers * num_units
    return 2 * (first_layer_weights + higher_layer_weights + all_biases)
  else:
    raise ValueError("%s direction is not supported." % direction)


class CudnnRNNTestSaveRestore(TensorFlowTestCase):

  def _CompareWeights(self, lhs, rhs):
    self.assertEqual(len(lhs), len(rhs))
    for lw, rw in zip(lhs, rhs):
      self.assertAllEqual(lw, rw)

  def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction):
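    """Compares two canonical bias lists layer by layer.

    Canonical params are grouped per layer; in bidirectional models each
    layer's group holds the forward-direction biases followed by the
    backward-direction biases.
    """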
    self.assertEqual(len(lhs), len(rhs))
    if rnn_mode == CUDNN_LSTM:
      num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER
    elif rnn_mode == CUDNN_GRU:
      num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER
    elif rnn_mode == CUDNN_RNN_TANH:
      num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER
    else:
      num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER
    num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2
    num_params_per_layer *= num_dirs
    self.assertEqual(num_params_per_layer * num_layers, len(lhs))

    for i in range(num_layers):
      layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer]
      layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer]
      if direction == CUDNN_RNN_UNIDIRECTION:
        self._CompareSingleLayerBiases(layer_lhs, layer_rhs)
      else:
        size = len(layer_lhs)
        fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:]
        fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:]
        self._CompareSingleLayerBiases(fw_lhs, fw_rhs)
        self._CompareSingleLayerBiases(bw_lhs, bw_rhs)

  def _CompareSingleLayerBiases(self, lhs, rhs):
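    """Compares one layer's biases up to cuDNN's non-unique bias split.

    cuDNN keeps two bias vectors per gate (one applied alongside the input
    weights, one alongside the recurrent weights). Only their sum affects
    the computation, so a save/restore round trip may legitimately
    redistribute values between the two halves; compare per-gate sums
    rather than the individual halves.
    """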
    self.assertEqual(len(lhs), len(rhs))

    lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:]
    lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:]
    self.assertEqual(len(lf_lhs), len(rt_lhs))
    self.assertEqual(len(lf_rhs), len(rt_rhs))

    sum_lhs, sum_rhs = [], []
    for lf, rt in zip(lf_lhs, rt_lhs):
      sum_lhs.append(lf + rt)
    for lf, rt in zip(lf_rhs, rt_rhs):
      sum_rhs.append(lf + rt)
    self.assertEqual(len(sum_lhs), len(sum_rhs))
    for lf, rt in zip(sum_lhs, sum_rhs):
      self.assertAllEqual(lf, rt)

  def _testSaveRestoreVariable(self, rnn_mode, direction, dtype):
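    """Saves and restores an opaque param blob; canonical params must match."""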
    num_layers = 2
    num_units = 7
    input_size = 3
    with ops.Graph().as_default():
      model = _CreateModel(
          rnn_mode,
          num_layers=num_layers,
          num_units=num_units,
          input_size=input_size,
          direction=direction,
          dtype=dtype)
      random_seed.set_random_seed(1234)
      params_size_t = model.params_size()
      params = variables.Variable(
          random_ops.random_uniform([params_size_t], dtype=dtype),
          dtype=dtype,
          validate_shape=False)
      saveable = _CreateParamsSavable(params, model)
      weights, biases = saveable._OpaqueParamsToCanonical()
      reset_params = state_ops.assign(
          params,
          array_ops.zeros([params_size_t], dtype=dtype),
          validate_shape=False)
      save_path = os.path.join(self.get_temp_dir(),
                               "save-restore-variable-test")
      saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
      # Pass the graph explicitly; otherwise an old session would be reused.
      with self.test_session(
          use_gpu=True, graph=ops.get_default_graph()) as sess:
        sess.run(variables.global_variables_initializer())
        val = saver.save(sess, save_path)
        self.assertEqual(save_path, val)

        weights_v, biases_v = sess.run([weights, biases])

        sess.run(reset_params)
        saver.restore(sess, save_path)
        weights_v_restored, biases_v_restored = sess.run([weights, biases])

        self._CompareWeights(weights_v, weights_v_restored)
        self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers,
                            direction)

  def _testSaveRestoreTwoVariables(self, rnn_mode, direction, dtype):
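    """Like _testSaveRestoreVariable, but with two independent param blobs."""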
    num_layers = 2
    num_units = 7
    input_size = 3
    with ops.Graph().as_default():
      model = _CreateModel(
          rnn_mode,
          num_layers=num_layers,
          num_units=num_units,
          input_size=input_size,
          direction=direction,
          dtype=dtype)
      random_seed.set_random_seed(1234)
      params_size_t = model.params_size()
      names = ["rnn_1", "rnn_2"]
      param_vars = [
          variables.Variable(
              random_ops.random_uniform([params_size_t], dtype=dtype),
              dtype=dtype,
              validate_shape=False) for _ in names
      ]
      saveables = []
      for name, params in zip(names, param_vars):
        saveables.append(_CreateParamsSavable(params, model, name, name))
      weights1, biases1 = saveables[0]._OpaqueParamsToCanonical()
      weights2, biases2 = saveables[1]._OpaqueParamsToCanonical()
      reset_params = [
          state_ops.assign(
              params,
              array_ops.zeros([params_size_t], dtype=dtype),
              validate_shape=False) for params in param_vars
      ]
      save_path = os.path.join(self.get_temp_dir(),
                               "save-restore-variable-test")
      saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
      # Pass the graph explicitly; otherwise an old session would be reused.
      with self.test_session(use_gpu=True,
                             graph=ops.get_default_graph()) as sess:
        sess.run(variables.global_variables_initializer())
        val = saver.save(sess, save_path)
        self.assertEqual(save_path, val)
        weights1_v, biases1_v = sess.run([weights1, biases1])
        weights2_v, biases2_v = sess.run([weights2, biases2])

        sess.run(reset_params)
        saver.restore(sess, save_path)
        weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1])
        weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2])

        self._CompareWeights(weights1_v, weights1_v_restored)
        self._CompareWeights(weights2_v, weights2_v_restored)
        self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers,
                            direction)
        self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers,
                            direction)

  def _testSaveRestoreOutput(self, rnn_mode, direction, dtype):
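    """Checks that inference outputs match before save and after restore."""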
    with ops.Graph().as_default():
      num_layers = 2
      num_units = 7
      input_size = 7
      seq_length = 10
      batch_size = 5
      dir_count = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2
      model = _CreateModel(
          rnn_mode,
          num_layers,
          num_units,
          input_size,
          direction=direction,
          dtype=dtype)
      params_size_t = model.params_size()
      params = variables.Variable(
          array_ops.ones([params_size_t], dtype=dtype),
          validate_shape=False,
          dtype=dtype)
      _CreateParamsSavable(params, model)
      save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test")
      saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

      np.random.seed(1234)
      has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM)
      input_data = constant_op.constant(
          np.random.randn(seq_length, batch_size, input_size), dtype=dtype)
      input_h = constant_op.constant(
          np.random.randn(num_layers * dir_count, batch_size, num_units),
          dtype=dtype)
      if has_input_c:
        input_c = constant_op.constant(
            np.random.randn(num_layers * dir_count, batch_size, num_units),
            dtype=dtype)
        outputs = model(
            input_data=input_data,
            input_h=input_h,
            input_c=input_c,
            params=params,
            is_training=False)
      else:
        outputs = model(
            input_data=input_data,
            input_h=input_h,
            params=params,
            is_training=False)
      total_sum = sum(map(math_ops.reduce_sum, outputs))
      # Pass the graph explicitly; otherwise an old session would be reused.
      with self.test_session(
          use_gpu=True, graph=ops.get_default_graph()) as sess:
        sess.run(variables.global_variables_initializer())
        total_sum_v = sess.run(total_sum)
        val = saver.save(sess, save_path)
        self.assertEqual(save_path, val)
      # Pass the graph explicitly; otherwise an old session would be reused.
      with self.test_session(
          use_gpu=True, graph=ops.get_default_graph()) as sess:
        reset_params = state_ops.assign(
            params,
            array_ops.zeros([params_size_t], dtype=dtype),
            validate_shape=False)
        sess.run(reset_params)
        saver.restore(sess, save_path)
        total_sum_v_restored = sess.run(total_sum)
        self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5)

  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testSaveRestore(self):
    rnn_modes = [
        cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU,
        cudnn_rnn_ops.CUDNN_RNN_TANH, cudnn_rnn_ops.CUDNN_RNN_RELU
    ]
    directions = [
        cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
        cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
    ]
    dtype_list = [dtypes.float32, dtypes.float64]
    for rnn_mode, direction, dtype in itertools.product(rnn_modes, directions,
                                                        dtype_list):
      self._testSaveRestoreVariable(rnn_mode, direction, dtype)
      self._testSaveRestoreTwoVariables(rnn_mode, direction, dtype)
      self._testSaveRestoreOutput(rnn_mode, direction, dtype)


class CudnnRNNTestParamsSize(TensorFlowTestCase):

  def _testOneLSTMParamsSize(self, num_layers, num_units, input_size,
                             direction):
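    """Checks that cuDNN's opaque LSTM param size is no smaller than the
    analytic lower bound from _MinLSTMParamSize."""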
    logging.info("Testing one lstm param size with config: %s", locals())
    min_params_size = _MinLSTMParamSize(num_layers, num_units, input_size,
                                        direction)
    model = _CreateModel(
        cudnn_rnn_ops.CUDNN_LSTM,
        num_layers,
        num_units,
        input_size,
        direction=direction)
    params_size = model.params_size()
    with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess:
      params_size_v = sess.run(params_size)
      self.assertLessEqual(min_params_size, params_size_v)

  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testLSTMParamsSize(self):
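    # Each config is [num_layers, num_units, input_size].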
    test_configs = [
        [4, 200, 200],
        [4, 200, 300],
        [4, 200, 100],
        [1, 100, 200],
        [2, 200, 100],
        [3, 200, 400],
    ]
    directions = [
        cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
        cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
    ]
    for (config, direction) in itertools.product(test_configs, directions):
      num_layers, num_units, input_size = config
      with ops.Graph().as_default():
        self._testOneLSTMParamsSize(num_layers, num_units, input_size,
                                    direction)


class CudnnRNNTestInference(TensorFlowTestCase):

  def _testOneSimpleInference(self, rnn_mode, num_layers, num_units, input_size,
                              batch_size, seq_length, dir_count, dropout,
                              expected, tolerance):
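    """Runs inference with all-ones inputs and params and compares the
    summed outputs against a precomputed expected value."""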
    random_seed.set_random_seed(5678)
    model = _CreateModel(
        rnn_mode,
        num_layers,
        num_units,
        input_size,
        input_mode="auto_select",
        direction=(cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1
                   else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION),
        dropout=dropout)
    has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM)
    params_size_t = model.params_size()
    input_data = array_ops.ones([seq_length, batch_size, input_size])
    input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units])
    params = variables.Variable(
        array_ops.ones([params_size_t]), validate_shape=False)
    if has_input_c:
      input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units])
      output, output_h, output_c = model(
          input_data=input_data,
          input_h=input_h,
          input_c=input_c,
          params=params,
          is_training=False)
    else:
      output, output_h = model(
          input_data=input_data,
          input_h=input_h,
          params=params,
          is_training=False)
    output_sum = math_ops.reduce_sum(output)
    output_h_sum = math_ops.reduce_sum(output_h)
    total_sum = output_sum + output_h_sum
    if has_input_c:
      output_c_sum = math_ops.reduce_sum(output_c)
      total_sum += output_c_sum
    with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess:
      sess.run(variables.global_variables_initializer())
      total_sum_v = sess.run([total_sum])

      self.assertAllClose(
          total_sum_v[0], expected, atol=tolerance, rtol=tolerance)

  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testSimpleInference(self):
    test_configs = [
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM,
            "expected": 231833.22,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 4,
                "num_units": 200,
                "input_size": 200,
                "batch_size": 20,
                "seq_length": 10,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_GRU,
            "expected": 56000,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 4,
                "num_units": 200,
                "input_size": 200,
                "batch_size": 20,
                "seq_length": 10,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH,
            "expected": 56000,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 4,
                "num_units": 200,
                "input_size": 200,
                "batch_size": 20,
                "seq_length": 10,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU,
            "expected": 130688,
            "tolerance": 1e-2,
            "shape": {
                "num_layers": 2,
                "num_units": 8,
                "input_size": 4,
                "batch_size": 4,
                "seq_length": 2,
                "dir_count": 1,
            },
        },
    ]
    # Cudnn scales results for dropout during training, so dropout has no
    # impact on inference results.
    # (lstm, gru and rnn_tanh are saturated in this test; the rnn_relu case
    # is the most demonstrative of the dropout-invariant nature of CudnnRnn
    # inference.)
    dropouts = [0., 0.5, 1.]
    for (config, dropout) in itertools.product(test_configs, dropouts):
      rnn_mode = config["rnn_mode"]
      expected = config["expected"]
      tolerance = config["tolerance"]
      shape = config["shape"]
      with ops.Graph().as_default():
        self._testOneSimpleInference(
            rnn_mode, shape["num_layers"], shape["num_units"],
            shape["input_size"], shape["batch_size"], shape["seq_length"],
            shape["dir_count"], dropout, expected, tolerance)


class CudnnRNNTestTraining(TensorFlowTestCase):

  def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size,
                             batch_size, seq_length, dir_count, dropout, dtype,
                             delta, tolerance):
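    """Numerically gradient-checks the summed RNN outputs with respect to
    the inputs, initial states and params."""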
    # Gradient checking runs two forward passes with almost the same inputs.
    # The dropout patterns across the two runs must be identical, hence
    # TF_CUDNN_RESET_RND_GEN_STATE is enabled for the duration of the test.
    logging.info("Training test with config: %s", locals())
    old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False))
    os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
    has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM)
    random_seed.set_random_seed(5678)
    direction = (cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1
                 else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
    model = _CreateModel(
        rnn_mode,
        num_layers,
        num_units,
        input_size,
        direction=direction,
        dtype=dtype,
        dropout=dropout)
    params_size_t = model.params_size()
    input_data = variables.Variable(
        random_ops.random_uniform(
            [seq_length, batch_size, input_size], dtype=dtype),
        dtype=dtype)
    input_h = variables.Variable(
        random_ops.random_uniform(
            [num_layers * dir_count, batch_size, num_units], dtype=dtype),
        dtype=dtype)
    params = variables.Variable(
        random_ops.random_uniform([params_size_t], dtype=dtype),
        validate_shape=False,
        dtype=dtype)
    if has_input_c:
      input_c = variables.Variable(
          random_ops.random_uniform(
              [num_layers * dir_count, batch_size, num_units], dtype=dtype),
          dtype=dtype)

      output, output_h, output_c = model(
          input_data=input_data,
          input_h=input_h,
          input_c=input_c,
          params=params)
    else:
      output, output_h = model(
          input_data=input_data, input_h=input_h, params=params)
    output_sum = math_ops.reduce_sum(output)
    output_h_sum = math_ops.reduce_sum(output_h)
    total_sum = output_sum + output_h_sum
    if has_input_c:
      output_c_sum = math_ops.reduce_sum(output_c)
      total_sum += output_c_sum

    with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess:
      params_size_v = sess.run(params_size_t)
      inputs_and_shapes = [
          (input_data, [seq_length, batch_size, input_size]),
          (input_h, [num_layers * dir_count, batch_size, num_units]),
          (params, [params_size_v]),
      ]
      if has_input_c:
        inputs_and_shapes.append(
            (input_c, [num_layers * dir_count, batch_size, num_units]))
      sess.run(variables.global_variables_initializer())
      all_inputs = [entry[0] for entry in inputs_and_shapes]
      all_shapes = [entry[1] for entry in inputs_and_shapes]

      err = gradient_checker.compute_gradient_error(
          all_inputs, all_shapes, total_sum, [1], delta=delta)

      self.assertLess(err, tolerance)
      os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state

  @unittest.skipUnless(test.is_built_with_cuda(),
                       "Test only applicable when running on GPUs")
  def testSimpleTraining(self):
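    # The float64 configs pin "delta" explicitly; the float32 configs rely on
    # the defaults (dtype=float32, delta=1e-3) applied in the loop below.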
    test_configs = [
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_GRU,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU,
            "dtype": dtypes.float64,
            "delta": 1e-4,
            "tolerance": 5e-6,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
                "dir_count": 1,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM,
            "dtype": dtypes.float32,
            "tolerance": 1.5e-2,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_GRU,
            "dtype": dtypes.float32,
            "tolerance": 4e-3,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH,
            "dtype": dtypes.float32,
            "tolerance": 5e-3,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
        {
            "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU,
            "dtype": dtypes.float32,
            "tolerance": 5e-1,
            "shape": {
                "num_layers": 2,
                "num_units": 3,
                "input_size": 4,
                "batch_size": 3,
                "seq_length": 4,
            },
        },
    ]
    dropouts = [0., 0.5, 1.]
    dir_counts = [1]
    for config, dropout, dir_count in itertools.product(test_configs, dropouts,
                                                        dir_counts):
      rnn_mode = config["rnn_mode"]
      dtype = config.get("dtype", dtypes.float32)
      delta = config.get("delta", 1e-3)
      tolerance = config["tolerance"]
      shape = config["shape"]
      with ops.Graph().as_default():
        self._testOneSimpleTraining(rnn_mode, shape["num_layers"],
                                    shape["num_units"], shape["input_size"],
                                    shape["batch_size"], shape["seq_length"],
                                    dir_count, dropout, dtype, delta, tolerance)


if __name__ == "__main__":
  googletest.main()