Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """LSTM Block Cell ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.rnn.python.kernel_tests import benchmarking
     24 from tensorflow.contrib.rnn.python.ops import lstm_ops
     25 from tensorflow.python.client import session
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import gradients_impl
     31 from tensorflow.python.ops import init_ops
     32 from tensorflow.python.ops import rnn
     33 from tensorflow.python.ops import rnn_cell
     34 from tensorflow.python.ops import variable_scope
     35 from tensorflow.python.ops import variables
     36 from tensorflow.python.platform import test
     37 
# Module-level alias for the private block LSTM op wrapper in lstm_ops; it is
# called directly by blocks_match() in this file.
block_lstm = lstm_ops._block_lstm  # pylint: disable=protected-access
     39 
     40 
     41 def blocks_match(sess, use_peephole):
     42   batch_size = 2
     43   input_size = 3
     44   cell_size = 4
     45   sequence_length = 4
     46 
     47   inputs = []
     48   for _ in range(sequence_length):
     49     inp = ops.convert_to_tensor(
     50         np.random.randn(batch_size, input_size), dtype=dtypes.float32)
     51     inputs.append(inp)
     52   stacked_inputs = array_ops.stack(inputs)
     53 
     54   initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212)
     55 
     56   with variable_scope.variable_scope("test", initializer=initializer):
     57     # magic naming so that the cells pick up these variables and resuse them
     58     if use_peephole:
     59       wci = variable_scope.get_variable(
     60           "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32)
     61       wcf = variable_scope.get_variable(
     62           "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtypes.float32)
     63       wco = variable_scope.get_variable(
     64           "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtypes.float32)
     65 
     66     w = variable_scope.get_variable(
     67         "rnn/lstm_cell/kernel",
     68         shape=[input_size + cell_size, cell_size * 4],
     69         dtype=dtypes.float32)
     70     b = variable_scope.get_variable(
     71         "rnn/lstm_cell/bias",
     72         shape=[cell_size * 4],
     73         dtype=dtypes.float32,
     74         initializer=init_ops.zeros_initializer())
     75 
     76     basic_cell = rnn_cell.LSTMCell(
     77         cell_size, use_peepholes=use_peephole, state_is_tuple=True, reuse=True)
     78     basic_outputs_op, basic_state_op = rnn.static_rnn(
     79         basic_cell, inputs, dtype=dtypes.float32)
     80 
     81     if use_peephole:
     82       _, _, _, _, _, _, block_outputs_op = block_lstm(
     83           ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
     84           inputs,
     85           w,
     86           b,
     87           wci=wci,
     88           wcf=wcf,
     89           wco=wco,
     90           cell_clip=0,
     91           use_peephole=True)
     92     else:
     93       _, _, _, _, _, _, block_outputs_op = block_lstm(
     94           ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
     95           inputs,
     96           w,
     97           b,
     98           cell_clip=0)
     99 
    100     fused_cell = lstm_ops.LSTMBlockFusedCell(
    101         cell_size, cell_clip=0, use_peephole=use_peephole, reuse=True,
    102         name="rnn/lstm_cell")
    103     fused_outputs_op, fused_state_op = fused_cell(
    104         stacked_inputs, dtype=dtypes.float32)
    105 
    106     sess.run([variables.global_variables_initializer()])
    107     basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]])
    108     basic_grads = sess.run(gradients_impl.gradients(basic_outputs_op, inputs))
    109     xs = [w, b]
    110     if use_peephole:
    111       xs += [wci, wcf, wco]
    112     basic_wgrads = sess.run(gradients_impl.gradients(basic_outputs_op, xs))
    113 
    114     block_outputs = sess.run(block_outputs_op)
    115     block_grads = sess.run(gradients_impl.gradients(block_outputs_op, inputs))
    116     block_wgrads = sess.run(gradients_impl.gradients(block_outputs_op, xs))
    117 
    118     xs = [w, b]
    119     if use_peephole:
    120       xs += [wci, wcf, wco]
    121     fused_outputs, fused_state = sess.run([fused_outputs_op, fused_state_op[0]])
    122     fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs))
    123     fused_wgrads = sess.run(gradients_impl.gradients(fused_outputs_op, xs))
    124 
    125     return (basic_state, fused_state, basic_outputs, block_outputs,
    126             fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads,
    127             block_wgrads, fused_wgrads)
    128 
    129 
    130 class LSTMBlockCellTest(test.TestCase):
    131 
    132   def testNoneDimsWithDynamicRNN(self):
    133     with self.test_session(use_gpu=True, graph=ops.Graph()) as sess:
    134       batch_size = 4
    135       num_steps = 5
    136       input_dim = 6
    137       cell_size = 7
    138 
    139       cell = lstm_ops.LSTMBlockCell(cell_size)
    140       x = array_ops.placeholder(dtypes.float32, shape=(None, None, input_dim))
    141 
    142       output, _ = rnn.dynamic_rnn(
    143           cell, x, time_major=True, dtype=dtypes.float32)
    144       sess.run(variables.global_variables_initializer())
    145       feed = {}
    146       feed[x] = np.random.randn(num_steps, batch_size, input_dim)
    147       sess.run(output, feed)
    148 
    149   def testLSTMBlockCell(self):
    150     with self.test_session(use_gpu=True, graph=ops.Graph()) as sess:
    151       with variable_scope.variable_scope(
    152           "root", initializer=init_ops.constant_initializer(0.5)):
    153         x = array_ops.zeros([1, 2])
    154         m0 = array_ops.zeros([1, 2])
    155         m1 = array_ops.zeros([1, 2])
    156         m2 = array_ops.zeros([1, 2])
    157         m3 = array_ops.zeros([1, 2])
    158         g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
    159             [lstm_ops.LSTMBlockCell(2)
    160              for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
    161         sess.run([variables.global_variables_initializer()])
    162         res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
    163             x.name: np.array([[1., 1.]]),
    164             m0.name: 0.1 * np.ones([1, 2]),
    165             m1.name: 0.1 * np.ones([1, 2]),
    166             m2.name: 0.1 * np.ones([1, 2]),
    167             m3.name: 0.1 * np.ones([1, 2])
    168         })
    169         self.assertEqual(len(res), 5)
    170         self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
    171         # These numbers are from testBasicLSTMCell and only test c/h.
    172         self.assertAllClose(res[1], [[0.68967271, 0.68967271]])
    173         self.assertAllClose(res[2], [[0.44848421, 0.44848421]])
    174         self.assertAllClose(res[3], [[0.39897051, 0.39897051]])
    175         self.assertAllClose(res[4], [[0.24024698, 0.24024698]])
    176 
    177   def testCompatibleNames(self):
    178     with self.test_session(use_gpu=True, graph=ops.Graph()):
    179       cell = rnn_cell.LSTMCell(10)
    180       pcell = rnn_cell.LSTMCell(10, use_peepholes=True)
    181       inputs = [array_ops.zeros([4, 5])] * 6
    182       rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
    183       rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
    184       basic_names = {
    185           v.name: v.get_shape()
    186           for v in variables.trainable_variables()
    187       }
    188 
    189     with self.test_session(use_gpu=True, graph=ops.Graph()):
    190       cell = lstm_ops.LSTMBlockCell(10)
    191       pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
    192       inputs = [array_ops.zeros([4, 5])] * 6
    193       rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
    194       rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
    195       block_names = {
    196           v.name: v.get_shape()
    197           for v in variables.trainable_variables()
    198       }
    199 
    200     with self.test_session(use_gpu=True, graph=ops.Graph()):
    201       cell = lstm_ops.LSTMBlockFusedCell(10)
    202       pcell = lstm_ops.LSTMBlockFusedCell(10, use_peephole=True)
    203       inputs = array_ops.stack([array_ops.zeros([4, 5])] * 6)
    204       cell(inputs, dtype=dtypes.float32, scope="basic/lstm_cell")
    205       pcell(inputs, dtype=dtypes.float32, scope="peephole/lstm_cell")
    206       fused_names = {
    207           v.name: v.get_shape()
    208           for v in variables.trainable_variables()
    209       }
    210 
    211     self.assertEqual(basic_names, block_names)
    212     self.assertEqual(basic_names, fused_names)
    213 
    214   def testLSTMBasicToBlockCell(self):
    215     with self.test_session(use_gpu=True) as sess:
    216       x = array_ops.zeros([1, 2])
    217       x_values = np.random.randn(1, 2)
    218 
    219       m0_val = 0.1 * np.ones([1, 2])
    220       m1_val = -0.1 * np.ones([1, 2])
    221       m2_val = -0.2 * np.ones([1, 2])
    222       m3_val = 0.2 * np.ones([1, 2])
    223 
    224       initializer = init_ops.random_uniform_initializer(
    225           -0.01, 0.01, seed=19890212)
    226       with variable_scope.variable_scope("basic", initializer=initializer):
    227         m0 = array_ops.zeros([1, 2])
    228         m1 = array_ops.zeros([1, 2])
    229         m2 = array_ops.zeros([1, 2])
    230         m3 = array_ops.zeros([1, 2])
    231         g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
    232             [rnn_cell.BasicLSTMCell(2, state_is_tuple=True) for _ in range(2)],
    233             state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
    234         sess.run([variables.global_variables_initializer()])
    235         basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
    236             x.name: x_values,
    237             m0.name: m0_val,
    238             m1.name: m1_val,
    239             m2.name: m2_val,
    240             m3.name: m3_val
    241         })
    242 
    243       with variable_scope.variable_scope("block", initializer=initializer):
    244         m0 = array_ops.zeros([1, 2])
    245         m1 = array_ops.zeros([1, 2])
    246         m2 = array_ops.zeros([1, 2])
    247         m3 = array_ops.zeros([1, 2])
    248         g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
    249             [lstm_ops.LSTMBlockCell(2)
    250              for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
    251         sess.run([variables.global_variables_initializer()])
    252         block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
    253             x.name: x_values,
    254             m0.name: m0_val,
    255             m1.name: m1_val,
    256             m2.name: m2_val,
    257             m3.name: m3_val
    258         })
    259 
    260       self.assertEqual(len(basic_res), len(block_res))
    261       for basic, block in zip(basic_res, block_res):
    262         self.assertAllClose(basic, block)
    263 
    264   def testLSTMBasicToBlockCellPeeping(self):
    265     with self.test_session(use_gpu=True) as sess:
    266       x = array_ops.zeros([1, 2])
    267       x_values = np.random.randn(1, 2)
    268 
    269       m0_val = 0.1 * np.ones([1, 2])
    270       m1_val = -0.1 * np.ones([1, 2])
    271       m2_val = -0.2 * np.ones([1, 2])
    272       m3_val = 0.2 * np.ones([1, 2])
    273 
    274       initializer = init_ops.random_uniform_initializer(
    275           -0.01, 0.01, seed=19890212)
    276       with variable_scope.variable_scope("basic", initializer=initializer):
    277         m0 = array_ops.zeros([1, 2])
    278         m1 = array_ops.zeros([1, 2])
    279         m2 = array_ops.zeros([1, 2])
    280         m3 = array_ops.zeros([1, 2])
    281         g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
    282             [
    283                 rnn_cell.LSTMCell(2, use_peepholes=True, state_is_tuple=True)
    284                 for _ in range(2)
    285             ],
    286             state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
    287         sess.run([variables.global_variables_initializer()])
    288         basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
    289             x.name: x_values,
    290             m0.name: m0_val,
    291             m1.name: m1_val,
    292             m2.name: m2_val,
    293             m3.name: m3_val
    294         })
    295 
    296       with variable_scope.variable_scope("block", initializer=initializer):
    297         m0 = array_ops.zeros([1, 2])
    298         m1 = array_ops.zeros([1, 2])
    299         m2 = array_ops.zeros([1, 2])
    300         m3 = array_ops.zeros([1, 2])
    301         g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
    302             [lstm_ops.LSTMBlockCell(2, use_peephole=True) for _ in range(2)],
    303             state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
    304         sess.run([variables.global_variables_initializer()])
    305         block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
    306             x.name: x_values,
    307             m0.name: m0_val,
    308             m1.name: m1_val,
    309             m2.name: m2_val,
    310             m3.name: m3_val
    311         })
    312 
    313       self.assertEqual(len(basic_res), len(block_res))
    314       for basic, block in zip(basic_res, block_res):
    315         self.assertAllClose(basic, block)
    316 
    317   def testLSTMBasicToBlock(self):
    318     with self.test_session(use_gpu=True) as sess:
    319       (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs,
    320        basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads,
    321        fused_wgrads) = blocks_match(
    322            sess, use_peephole=False)
    323 
    324       self.assertAllClose(basic_outputs, block_outputs)
    325       self.assertAllClose(basic_grads, block_grads)
    326       for basic, block in zip(basic_wgrads, block_wgrads):
    327         self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6)
    328 
    329       self.assertAllClose(basic_outputs, fused_outputs)
    330       self.assertAllClose(basic_state, fused_state)
    331       self.assertAllClose(basic_grads, fused_grads)
    332       for basic, fused in zip(block_wgrads, fused_wgrads):
    333         self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6)
    334 
    335   def testLSTMBasicToBlockPeeping(self):
    336     with self.test_session(use_gpu=True) as sess:
    337       (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs,
    338        basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads,
    339        fused_wgrads) = blocks_match(
    340            sess, use_peephole=True)
    341 
    342       self.assertAllClose(basic_outputs, block_outputs)
    343       self.assertAllClose(basic_grads, block_grads)
    344       for basic, block in zip(basic_wgrads, block_wgrads):
    345         self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6)
    346 
    347       self.assertAllClose(basic_outputs, fused_outputs)
    348       self.assertAllClose(basic_state, fused_state)
    349       self.assertAllClose(basic_grads, fused_grads)
    350       for basic, fused in zip(block_wgrads, fused_wgrads):
    351         self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6)
    352 
    353   def testLSTMFusedSequenceLengths(self):
    354     """Verify proper support for sequence lengths in LSTMBlockFusedCell."""
    355     with self.test_session(use_gpu=True) as sess:
    356       batch_size = 3
    357       input_size = 4
    358       cell_size = 5
    359       max_sequence_length = 6
    360 
    361       inputs = []
    362       for _ in range(max_sequence_length):
    363         inp = ops.convert_to_tensor(
    364             np.random.randn(batch_size, input_size), dtype=dtypes.float32)
    365         inputs.append(inp)
    366       seq_lengths = constant_op.constant([3, 4, 5])
    367       cell_inputs = array_ops.stack(inputs)
    368 
    369       initializer = init_ops.random_uniform_initializer(
    370           -0.01, 0.01, seed=19890213)
    371 
    372       with variable_scope.variable_scope("lstm_cell", initializer=initializer):
    373         # magic naming so that the cells pick up these variables and reuse them
    374         variable_scope.get_variable(
    375             "kernel",
    376             shape=[input_size + cell_size, cell_size * 4],
    377             dtype=dtypes.float32)
    378 
    379         variable_scope.get_variable(
    380             "bias",
    381             shape=[cell_size * 4],
    382             dtype=dtypes.float32,
    383             initializer=init_ops.zeros_initializer())
    384 
    385       cell = lstm_ops.LSTMBlockFusedCell(
    386           cell_size, cell_clip=0, use_peephole=False, reuse=True,
    387           name="lstm_cell")
    388 
    389       fused_outputs_op, fused_state_op = cell(
    390           cell_inputs, dtype=dtypes.float32, sequence_length=seq_lengths)
    391 
    392       cell_vars = [
    393           v for v in variables.trainable_variables()
    394           if v.name.endswith("kernel") or v.name.endswith("bias")
    395       ]
    396 
    397       # Verify that state propagation works if we turn our sequence into
    398       # tiny (single-time) subsequences, i.e. unfuse the cell
    399       unfused_outputs_op = []
    400       state = None
    401       with variable_scope.variable_scope(
    402           variable_scope.get_variable_scope(), reuse=True):
    403         for i, inp in enumerate(inputs):
    404           lengths = [int(i < l) for l in seq_lengths.eval()]
    405           output, state = cell(
    406               array_ops.expand_dims(inp, 0),
    407               initial_state=state,
    408               dtype=dtypes.float32,
    409               sequence_length=lengths)
    410           unfused_outputs_op.append(output[0])
    411       unfused_outputs_op = array_ops.stack(unfused_outputs_op)
    412 
    413       sess.run([variables.global_variables_initializer()])
    414       unfused_outputs, unfused_state = sess.run([unfused_outputs_op, state[0]])
    415       unfused_grads = sess.run(
    416           gradients_impl.gradients(unfused_outputs_op, inputs))
    417       unfused_wgrads = sess.run(
    418           gradients_impl.gradients(unfused_outputs_op, cell_vars))
    419 
    420       fused_outputs, fused_state = sess.run(
    421           [fused_outputs_op, fused_state_op[0]])
    422       fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs))
    423       fused_wgrads = sess.run(
    424           gradients_impl.gradients(fused_outputs_op, cell_vars))
    425 
    426       self.assertAllClose(fused_outputs, unfused_outputs)
    427       self.assertAllClose(fused_state, unfused_state)
    428       self.assertAllClose(fused_grads, unfused_grads)
    429       for fused, unfused in zip(fused_wgrads, unfused_wgrads):
    430         self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6)
    431 
    432 #### Benchmarking.
    433 
    434 
    435 class BenchmarkLSTMBlock(test.Benchmark):
    436 
    437   def benchmarkLSTMBlockCellFpropWithDynamicRNN(self):
    438     print("BlockLSTMCell forward propagation via dynamic_rnn().")
    439     print("--------------------------------------------------------------")
    440     print("LSTMBlockCell Seconds per inference.")
    441     print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
    442     iters = 10
    443     for config in benchmarking.dict_product({
    444         "batch_size": [1, 8, 13, 32, 67, 128],
    445         "cell_size": [128, 250, 512, 650, 1024, 1350],
    446         "time_steps": [40],
    447         "use_gpu": [True, False]
    448     }):
    449       with ops.Graph().as_default():
    450         with benchmarking.device(use_gpu=config["use_gpu"]):
    451           inputs = variable_scope.get_variable(
    452               "x",
    453               [config["time_steps"], config["batch_size"], config["cell_size"]])
    454           cell = lstm_ops.LSTMBlockCell(config["cell_size"])
    455           outputs = rnn.dynamic_rnn(
    456               cell, inputs, time_major=True, dtype=dtypes.float32)
    457           init_op = variables.global_variables_initializer()
    458 
    459         with session.Session() as sess:
    460           sess.run(init_op)
    461           wall_time = benchmarking.seconds_per_run(outputs, sess, iters)
    462 
    463         # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
    464         # is set, this will produce a copy-paste-able CSV file.
    465         print(",".join(
    466             map(str, [
    467                 config["batch_size"], config["cell_size"], config["cell_size"],
    468                 config["time_steps"], config["use_gpu"], wall_time
    469             ])))
    470         benchmark_name_template = "_".join([
    471             "LSTMBlockCell_fprop", "BS%(batch_size)i", "CS%(cell_size)i",
    472             "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s"
    473         ])
    474 
    475         self.report_benchmark(
    476             name=benchmark_name_template % config,
    477             iters=iters,
    478             wall_time=wall_time,
    479             extras=config)
    480 
    481   def benchmarkLSTMBlockCellBpropWithDynamicRNN(self):
    482     print("BlockLSTMCell backward propagation via dynamic_rnn().")
    483     print("--------------------------------------------------------------")
    484     print("LSTMBlockCell Seconds per inference.")
    485     print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
    486     iters = 10
    487     for config in benchmarking.dict_product({
    488         "batch_size": [1, 8, 13, 32, 67, 128],
    489         "cell_size": [128, 250, 512, 650, 1024, 1350],
    490         "time_steps": [40],
    491         "use_gpu": [True, False]
    492     }):
    493       with ops.Graph().as_default():
    494         with benchmarking.device(use_gpu=config["use_gpu"]):
    495           time_steps = config["time_steps"]
    496           batch_size = config["batch_size"]
    497           cell_size = input_size = config["cell_size"]
    498           inputs = variable_scope.get_variable(
    499               "x", [time_steps, batch_size, cell_size],
    500               trainable=False,
    501               dtype=dtypes.float32)
    502           with variable_scope.variable_scope(
    503               "rnn", reuse=variable_scope.AUTO_REUSE):
    504             w = variable_scope.get_variable(
    505                 "rnn/lstm_cell/kernel",
    506                 shape=[input_size + cell_size, cell_size * 4],
    507                 dtype=dtypes.float32)
    508             b = variable_scope.get_variable(
    509                 "rnn/lstm_cell/bias",
    510                 shape=[cell_size * 4],
    511                 dtype=dtypes.float32,
    512                 initializer=init_ops.zeros_initializer())
    513             cell = lstm_ops.LSTMBlockCell(cell_size)
    514             outputs = rnn.dynamic_rnn(
    515                 cell, inputs, time_major=True, dtype=dtypes.float32)
    516           grads = gradients_impl.gradients(outputs, [inputs, w, b])
    517           init_op = variables.global_variables_initializer()
    518 
    519         with session.Session() as sess:
    520           sess.run(init_op)
    521           wall_time = benchmarking.seconds_per_run(grads, sess, iters)
    522 
    523         # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
    524         # is set, this will produce a copy-paste-able CSV file.
    525         print(",".join(
    526             map(str, [
    527                 batch_size, cell_size, cell_size, time_steps, config["use_gpu"],
    528                 wall_time
    529             ])))
    530         benchmark_name_template = "_".join([
    531             "LSTMBlockCell_bprop", "BS%(batch_size)i", "CS%(cell_size)i",
    532             "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s"
    533         ])
    534 
    535         self.report_benchmark(
    536             name=benchmark_name_template % config,
    537             iters=iters,
    538             wall_time=wall_time,
    539             extras=config)
    540 
    541 
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
    544