Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for RNN cells."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 
     23 import numpy as np
     24 
     25 from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
     26 from tensorflow.core.protobuf import config_pb2
     27 from tensorflow.python.client import session
     28 from tensorflow.python.framework import constant_op
     29 from tensorflow.python.framework import dtypes
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import random_seed
     32 from tensorflow.python.ops import array_ops
     33 from tensorflow.python.ops import control_flow_ops
     34 from tensorflow.python.ops import gradients_impl
     35 from tensorflow.python.ops import init_ops
     36 from tensorflow.python.ops import math_ops
     37 from tensorflow.python.ops import random_ops
     38 from tensorflow.python.ops import rnn
     39 from tensorflow.python.ops import rnn_cell
     40 from tensorflow.python.ops import rnn_cell_impl
     41 from tensorflow.python.ops import variable_scope
     42 from tensorflow.python.ops import variables
     43 from tensorflow.python.platform import test
     44 from tensorflow.python.util import nest
     45 
     46 
     47 class RNNCellTest(test.TestCase):
     48 
     49   def testCoupledInputForgetGateLSTMCell(self):
     50     with self.test_session() as sess:
     51       num_units = 2
     52       state_size = num_units * 2
     53       batch_size = 3
     54       input_size = 4
     55       expected_output = np.array(
     56           [[0.121753, 0.121753], [0.103349, 0.103349], [0.100178, 0.100178]],
     57           dtype=np.float32)
     58       expected_state = np.array(
     59           [[0.137523, 0.137523, 0.121753, 0.121753], [
     60               0.105450, 0.105450, 0.103349, 0.103349
     61           ], [0.100742, 0.100742, 0.100178, 0.100178]],
     62           dtype=np.float32)
     63       with variable_scope.variable_scope(
     64           "root", initializer=init_ops.constant_initializer(0.5)):
     65         x = array_ops.zeros([batch_size, input_size])
     66         m = array_ops.zeros([batch_size, state_size])
     67         output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
     68             num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
     69         sess.run([variables.global_variables_initializer()])
     70         res = sess.run(
     71             [output, state], {
     72                 x.name:
     73                     np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
     74                               [3., 3., 3., 3.]]),
     75                 m.name:
     76                     0.1 * np.ones((batch_size, state_size))
     77             })
     78         # This is a smoke test: Only making sure expected values didn't change.
     79         self.assertEqual(len(res), 2)
     80         self.assertAllClose(res[0], expected_output)
     81         self.assertAllClose(res[1], expected_state)
     82 
     83   def testTimeFreqLSTMCell(self):
     84     with self.test_session() as sess:
     85       num_units = 8
     86       state_size = num_units * 2
     87       batch_size = 3
     88       input_size = 4
     89       feature_size = 2
     90       frequency_skip = 1
     91       num_shifts = (input_size - feature_size) // frequency_skip + 1
     92       with variable_scope.variable_scope(
     93           "root", initializer=init_ops.constant_initializer(0.5)):
     94         x = array_ops.zeros([batch_size, input_size])
     95         m = array_ops.zeros([batch_size, state_size * num_shifts])
     96         output, state = contrib_rnn_cell.TimeFreqLSTMCell(
     97             num_units=num_units,
     98             feature_size=feature_size,
     99             frequency_skip=frequency_skip,
    100             forget_bias=1.0)(x, m)
    101         sess.run([variables.global_variables_initializer()])
    102         res = sess.run(
    103             [output, state], {
    104                 x.name:
    105                     np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
    106                               [3., 3., 3., 3.]]),
    107                 m.name:
    108                     0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
    109             })
    110         self.assertEqual(len(res), 2)
    111         # The numbers in results were not calculated, this is mostly just a
    112         # smoke test.
    113         self.assertEqual(res[0].shape, (batch_size, num_units * num_shifts))
    114         self.assertEqual(res[1].shape, (batch_size, state_size * num_shifts))
    115         # Different inputs so different outputs and states
    116         for i in range(1, batch_size):
    117           self.assertTrue(
    118               float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
    119           self.assertTrue(
    120               float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
    121 
    122   def testGridLSTMCell(self):
    123     with self.test_session() as sess:
    124       num_units = 8
    125       batch_size = 3
    126       input_size = 4
    127       feature_size = 2
    128       frequency_skip = 1
    129       num_shifts = int((input_size - feature_size) / frequency_skip + 1)
    130       with variable_scope.variable_scope(
    131           "root", initializer=init_ops.constant_initializer(0.5)):
    132         cell = contrib_rnn_cell.GridLSTMCell(
    133             num_units=num_units,
    134             feature_size=feature_size,
    135             frequency_skip=frequency_skip,
    136             forget_bias=1.0,
    137             num_frequency_blocks=[num_shifts],
    138             couple_input_forget_gates=True,
    139             state_is_tuple=True)
    140         inputs = constant_op.constant(
    141             np.array(
    142                 [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
    143                 dtype=np.float32),
    144             dtype=dtypes.float32)
    145         state_value = constant_op.constant(
    146             0.1 * np.ones((batch_size, num_units), dtype=np.float32),
    147             dtype=dtypes.float32)
    148         init_state = cell.state_tuple_type(*(
    149             [state_value, state_value] * num_shifts))
    150         output, state = cell(inputs, init_state)
    151         sess.run([variables.global_variables_initializer()])
    152         res = sess.run([output, state])
    153         self.assertEqual(len(res), 2)
    154         # The numbers in results were not calculated, this is mostly just a
    155         # smoke test.
    156         self.assertEqual(res[0].shape, (batch_size, num_units * num_shifts * 2))
    157         for ss in res[1]:
    158           self.assertEqual(ss.shape, (batch_size, num_units))
    159         # Different inputs so different outputs and states
    160         for i in range(1, batch_size):
    161           self.assertTrue(
    162               float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
    163           self.assertTrue(
    164               float(
    165                   np.linalg.norm((res[1].state_f00_b00_c[0, :] - res[1]
    166                                   .state_f00_b00_c[i, :]))) > 1e-6)
    167 
  def testGridLSTMCellWithFrequencyBlocks(self):
    """Shape smoke test for GridLSTMCell with two explicit frequency blocks."""
    with self.test_session() as sess:
      num_units = 8
      batch_size = 3
      feature_size = 2
      frequency_skip = 1
      num_frequency_blocks = [1, 1]
      total_blocks = num_frequency_blocks[0] + num_frequency_blocks[1]
      # Each block spans the [start, end) frequency index range below.
      start_freqindex_list = [0, 2]
      end_freqindex_list = [2, 4]
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.GridLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            frequency_skip=frequency_skip,
            forget_bias=1.0,
            num_frequency_blocks=num_frequency_blocks,
            start_freqindex_list=start_freqindex_list,
            end_freqindex_list=end_freqindex_list,
            couple_input_forget_gates=True,
            state_is_tuple=True)
        inputs = constant_op.constant(
            np.array(
                [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        # Two state tensors per block (presumably the (c, m) pair — the
        # layout comes from cell.state_tuple_type).
        init_state = cell.state_tuple_type(*(
            [state_value, state_value] * total_blocks))
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape,
                         (batch_size, num_units * total_blocks * 2))
        for ss in res[1]:
          self.assertEqual(ss.shape, (batch_size, num_units))
        # Different inputs so different outputs and states
        for i in range(1, batch_size):
          self.assertTrue(
              float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6)
          self.assertTrue(
              float(
                  np.linalg.norm((res[1].state_f00_b00_c[0, :] - res[1]
                                  .state_f00_b00_c[i, :]))) > 1e-6)
    218 
  def testGridLstmCellWithCoupledInputForgetGates(self):
    """GridLSTMCell with coupled gates matches pinned reference values.

    Runs with both state layouts (tuple and concatenated) and checks the
    output/state against values recorded from an earlier run.
    """
    num_units = 2
    batch_size = 3
    input_size = 4
    feature_size = 2
    frequency_skip = 1
    num_shifts = int((input_size - feature_size) / frequency_skip + 1)
    # Pinned reference values; not hand-derived (see smoke-test note below).
    expected_output = np.array(
        [[
            0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
            0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699
        ], [
            0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
            0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171
        ], [
            0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
            0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250
        ]],
        dtype=np.float32)
    expected_state = np.array(
        [[
            0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
            0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865
        ], [
            0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
            0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245
        ], [
            0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
            0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390
        ]],
        dtype=np.float32)
    for state_is_tuple in [False, True]:
      with self.test_session() as sess:
        # Separate scope per iteration so variables are created fresh.
        with variable_scope.variable_scope(
            "state_is_tuple" + str(state_is_tuple),
            initializer=init_ops.constant_initializer(0.5)):
          cell = contrib_rnn_cell.GridLSTMCell(
              num_units=num_units,
              feature_size=feature_size,
              frequency_skip=frequency_skip,
              forget_bias=1.0,
              num_frequency_blocks=[num_shifts],
              couple_input_forget_gates=True,
              state_is_tuple=state_is_tuple)
          inputs = constant_op.constant(
              np.array(
                  [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
                  dtype=np.float32),
              dtype=dtypes.float32)
          if state_is_tuple:
            # Tuple layout: two tensors per frequency shift.
            state_value = constant_op.constant(
                0.1 * np.ones((batch_size, num_units), dtype=np.float32),
                dtype=dtypes.float32)
            init_state = cell.state_tuple_type(*(
                [state_value, state_value] * num_shifts))
          else:
            # Concatenated layout: one tensor holding all state parts.
            init_state = constant_op.constant(
                0.1 * np.ones(
                    (batch_size, num_units * num_shifts * 2), dtype=np.float32),
                dtype=dtypes.float32)
          output, state = cell(inputs, init_state)
          sess.run([variables.global_variables_initializer()])
          res = sess.run([output, state])
          # This is a smoke test: Only making sure expected values not change.
          self.assertEqual(len(res), 2)
          self.assertAllClose(res[0], expected_output)
          if not state_is_tuple:
            self.assertAllClose(res[1], expected_state)
          else:
            # There should be num_shifts * 2 states in the tuple.
            self.assertEqual(len(res[1]), num_shifts * 2)
            # Checking the shape of each state to be batch_size * num_units
            for ss in res[1]:
              self.assertEqual(ss.shape[0], batch_size)
              self.assertEqual(ss.shape[1], num_units)
            self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
    295 
  def testBidirectionGridLSTMCell(self):
    """BidirectionalGridLSTMCell output/state match pinned reference values.

    Uses shared time/frequency weights; outputs hold both directions, so
    they are 2x wider and the state tuple 2x longer than the unidirectional
    case.
    """
    with self.test_session() as sess:
      num_units = 2
      batch_size = 3
      input_size = 4
      feature_size = 2
      frequency_skip = 1
      num_shifts = int((input_size - feature_size) / frequency_skip + 1)
      # Pinned reference values recorded from an earlier run.
      expected_output = np.array(
          [[
              0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
              0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
              0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
              0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218
          ], [
              0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
              0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
              0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
              0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840
          ], [
              0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
              0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
              0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
              0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731
          ]],
          dtype=np.float32)
      expected_state = np.array(
          [[
              0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
              0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
              0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
              0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773
          ], [
              0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
              0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
              0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
              0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984
          ], [
              1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
              0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
              1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
              0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035
          ]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            share_time_frequency_weights=True,
            frequency_skip=frequency_skip,
            forget_bias=1.0,
            num_frequency_blocks=[num_shifts])
        inputs = constant_op.constant(
            np.array(
                [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
                 [3.0, 3.1, 3.2, 3.3]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        # Two tensors per shift per direction, hence the extra factor of 2.
        init_state = cell.state_tuple_type(*(
            [state_value, state_value] * num_shifts * 2))
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape, (batch_size, num_units * num_shifts * 4))
        self.assertAllClose(res[0], expected_output)
        # There should be num_shifts * 4 states in the tuple.
        self.assertEqual(len(res[1]), num_shifts * 4)
        # Checking the shape of each state to be batch_size * num_units
        for ss in res[1]:
          self.assertEqual(ss.shape[0], batch_size)
          self.assertEqual(ss.shape[1], num_units)
        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
    375 
  def testBidirectionGridLSTMCellWithSliceOffset(self):
    """BidirectionalGridLSTMCell with backward_slice_offset=1.

    Same setup as testBidirectionGridLSTMCell, but the backward direction
    is sliced with an offset of one, so the backward half of the pinned
    reference values differs from the unoffset test.
    """
    with self.test_session() as sess:
      num_units = 2
      batch_size = 3
      input_size = 4
      feature_size = 2
      frequency_skip = 1
      num_shifts = int((input_size - feature_size) / frequency_skip + 1)
      # Pinned reference values recorded from an earlier run.
      expected_output = np.array(
          [[
              0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
              0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
              0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
              0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071
          ], [
              0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
              0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
              0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
              0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958
          ], [
              0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
              0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
              0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
              0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858
          ]],
          dtype=np.float32)
      expected_state = np.array(
          [[
              0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
              0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
              0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
              0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446
          ], [
              0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
              0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
              0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
              0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429
          ], [
              1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
              0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
              0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
              0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597
          ]],
          dtype=np.float32)
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = contrib_rnn_cell.BidirectionalGridLSTMCell(
            num_units=num_units,
            feature_size=feature_size,
            share_time_frequency_weights=True,
            frequency_skip=frequency_skip,
            forget_bias=1.0,
            num_frequency_blocks=[num_shifts],
            backward_slice_offset=1)
        inputs = constant_op.constant(
            np.array(
                [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
                 [3.0, 3.1, 3.2, 3.3]],
                dtype=np.float32),
            dtype=dtypes.float32)
        state_value = constant_op.constant(
            0.1 * np.ones((batch_size, num_units), dtype=np.float32),
            dtype=dtypes.float32)
        # Two tensors per shift per direction, hence the extra factor of 2.
        init_state = cell.state_tuple_type(*(
            [state_value, state_value] * num_shifts * 2))
        output, state = cell(inputs, init_state)
        sess.run([variables.global_variables_initializer()])
        res = sess.run([output, state])
        self.assertEqual(len(res), 2)
        # The numbers in results were not calculated, this is mostly just a
        # smoke test.
        self.assertEqual(res[0].shape, (batch_size, num_units * num_shifts * 4))
        self.assertAllClose(res[0], expected_output)
        # There should be num_shifts * 4 states in the tuple.
        self.assertEqual(len(res[1]), num_shifts * 4)
        # Checking the shape of each state to be batch_size * num_units
        for ss in res[1]:
          self.assertEqual(ss.shape[0], batch_size)
          self.assertEqual(ss.shape[1], num_units)
        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
    456 
    457   def testAttentionCellWrapperFailures(self):
    458     with self.assertRaisesRegexp(TypeError,
    459                                  "The parameter cell is not RNNCell."):
    460       contrib_rnn_cell.AttentionCellWrapper(None, 0)
    461 
    462     num_units = 8
    463     for state_is_tuple in [False, True]:
    464       with ops.Graph().as_default():
    465         lstm_cell = rnn_cell.BasicLSTMCell(
    466             num_units, state_is_tuple=state_is_tuple)
    467         with self.assertRaisesRegexp(
    468             ValueError, "attn_length should be greater than zero, got 0"):
    469           contrib_rnn_cell.AttentionCellWrapper(
    470               lstm_cell, 0, state_is_tuple=state_is_tuple)
    471         with self.assertRaisesRegexp(
    472             ValueError, "attn_length should be greater than zero, got -1"):
    473           contrib_rnn_cell.AttentionCellWrapper(
    474               lstm_cell, -1, state_is_tuple=state_is_tuple)
    475       with ops.Graph().as_default():
    476         lstm_cell = rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)
    477         with self.assertRaisesRegexp(
    478             ValueError, "Cell returns tuple of states, but the flag "
    479             "state_is_tuple is not set. State size is: *"):
    480           contrib_rnn_cell.AttentionCellWrapper(
    481               lstm_cell, 4, state_is_tuple=False)
    482 
    483   def testAttentionCellWrapperZeros(self):
    484     num_units = 8
    485     attn_length = 16
    486     batch_size = 3
    487     input_size = 4
    488     for state_is_tuple in [False, True]:
    489       with ops.Graph().as_default():
    490         with self.test_session() as sess:
    491           with variable_scope.variable_scope(
    492               "state_is_tuple_" + str(state_is_tuple)):
    493             lstm_cell = rnn_cell.BasicLSTMCell(
    494                 num_units, state_is_tuple=state_is_tuple)
    495             cell = contrib_rnn_cell.AttentionCellWrapper(
    496                 lstm_cell, attn_length, state_is_tuple=state_is_tuple)
    497             if state_is_tuple:
    498               zeros = array_ops.zeros([batch_size, num_units], dtype=np.float32)
    499               attn_state_zeros = array_ops.zeros(
    500                   [batch_size, attn_length * num_units], dtype=np.float32)
    501               zero_state = ((zeros, zeros), zeros, attn_state_zeros)
    502             else:
    503               zero_state = array_ops.zeros(
    504                   [
    505                       batch_size,
    506                       num_units * 2 + attn_length * num_units + num_units
    507                   ],
    508                   dtype=np.float32)
    509             inputs = array_ops.zeros(
    510                 [batch_size, input_size], dtype=dtypes.float32)
    511             output, state = cell(inputs, zero_state)
    512             self.assertEquals(output.get_shape(), [batch_size, num_units])
    513             if state_is_tuple:
    514               self.assertEquals(len(state), 3)
    515               self.assertEquals(len(state[0]), 2)
    516               self.assertEquals(state[0][0].get_shape(),
    517                                 [batch_size, num_units])
    518               self.assertEquals(state[0][1].get_shape(),
    519                                 [batch_size, num_units])
    520               self.assertEquals(state[1].get_shape(), [batch_size, num_units])
    521               self.assertEquals(state[2].get_shape(),
    522                                 [batch_size, attn_length * num_units])
    523               tensors = [output] + list(state)
    524             else:
    525               self.assertEquals(state.get_shape(), [
    526                   batch_size,
    527                   num_units * 2 + num_units + attn_length * num_units
    528               ])
    529               tensors = [output, state]
    530             zero_result = sum(
    531                 [math_ops.reduce_sum(math_ops.abs(x)) for x in tensors])
    532             sess.run(variables.global_variables_initializer())
    533             self.assertTrue(sess.run(zero_result) < 1e-6)
    534 
    535   def testAttentionCellWrapperValues(self):
    536     num_units = 8
    537     attn_length = 16
    538     batch_size = 3
    539     for state_is_tuple in [False, True]:
    540       with ops.Graph().as_default():
    541         with self.test_session() as sess:
    542           with variable_scope.variable_scope(
    543               "state_is_tuple_" + str(state_is_tuple)):
    544             lstm_cell = rnn_cell.BasicLSTMCell(
    545                 num_units, state_is_tuple=state_is_tuple)
    546             cell = contrib_rnn_cell.AttentionCellWrapper(
    547                 lstm_cell, attn_length, state_is_tuple=state_is_tuple)
    548             if state_is_tuple:
    549               zeros = constant_op.constant(
    550                   0.1 * np.ones([batch_size, num_units], dtype=np.float32),
    551                   dtype=dtypes.float32)
    552               attn_state_zeros = constant_op.constant(
    553                   0.1 * np.ones(
    554                       [batch_size, attn_length * num_units], dtype=np.float32),
    555                   dtype=dtypes.float32)
    556               zero_state = ((zeros, zeros), zeros, attn_state_zeros)
    557             else:
    558               zero_state = constant_op.constant(
    559                   0.1 * np.ones(
    560                       [
    561                           batch_size,
    562                           num_units * 2 + num_units + attn_length * num_units
    563                       ],
    564                       dtype=np.float32),
    565                   dtype=dtypes.float32)
    566             inputs = constant_op.constant(
    567                 np.array(
    568                     [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
    569                     dtype=np.float32),
    570                 dtype=dtypes.float32)
    571             output, state = cell(inputs, zero_state)
    572             if state_is_tuple:
    573               concat_state = array_ops.concat(
    574                   [state[0][0], state[0][1], state[1], state[2]], 1)
    575             else:
    576               concat_state = state
    577             sess.run(variables.global_variables_initializer())
    578             output, state = sess.run([output, concat_state])
    579             # Different inputs so different outputs and states
    580             for i in range(1, batch_size):
    581               self.assertTrue(
    582                   float(np.linalg.norm((output[0, :] - output[i, :]))) > 1e-6)
    583               self.assertTrue(
    584                   float(np.linalg.norm((state[0, :] - state[i, :]))) > 1e-6)
    585 
  def _testAttentionCellWrapperCorrectResult(self):
    """Golden-value check for AttentionCellWrapper around a BasicLSTMCell.

    Runs the wrapped cell once with seeded random inputs/state, in both the
    concatenated-state and tuple-state representations, and compares the
    results against precomputed constants.  The leading underscore keeps the
    test from being auto-discovered by the runner.
    """
    num_units = 4
    attn_length = 6
    batch_size = 2
    # Precomputed golden values for the single cell step below.
    expected_output = np.array(
        [[1.068372, 0.45496, -0.678277, 0.340538],
         [1.018088, 0.378983, -0.572179, 0.268591]],
        dtype=np.float32)
    expected_state = np.array(
        [[
            0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
            0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
            0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
            0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
            0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
            0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
            0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
            0.51843399
        ], [
            0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
            0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
            0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
            0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
            0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
            0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
            0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
            0.70582712
        ]],
        dtype=np.float32)
    # Fixed graph and op seeds make the random inputs/state reproducible.
    seed = 12345
    random_seed.set_random_seed(seed)
    rnn_scope = None
    # Both state layouts must produce the same numbers; the second pass
    # reuses the variables created by the first (reuse=state_is_tuple).
    for state_is_tuple in [False, True]:
      with session.Session() as sess:
        with variable_scope.variable_scope(
            "state_is_tuple",
            reuse=state_is_tuple,
            initializer=init_ops.glorot_uniform_initializer()):
          lstm_cell = rnn_cell.BasicLSTMCell(
              num_units, state_is_tuple=state_is_tuple)
          cell = contrib_rnn_cell.AttentionCellWrapper(
              lstm_cell, attn_length, state_is_tuple=state_is_tuple)
          # This is legacy behavior to preserve the test.  Weight
          # sharing no longer works by creating a new RNNCell in the
          # same variable scope; so here we restore the scope of the
          # RNNCells after the first use below.
          if rnn_scope is not None:
            (cell._scope, lstm_cell._scope) = rnn_scope  # pylint: disable=protected-access,unpacking-non-sequence
          # Seeded random values stand in for a real zero state so the
          # golden outputs are nontrivial.
          zeros1 = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 1)
          zeros2 = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 2)
          zeros3 = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 3)
          attn_state_zeros = random_ops.random_uniform(
              (batch_size, attn_length * num_units), 0.0, 1.0, seed=seed + 4)
          zero_state = ((zeros1, zeros2), zeros3, attn_state_zeros)
          if not state_is_tuple:
            # Flatten the tuple state into one concatenated tensor.
            zero_state = array_ops.concat([
                zero_state[0][0], zero_state[0][1], zero_state[1], zero_state[2]
            ], 1)
          inputs = random_ops.random_uniform(
              (batch_size, num_units), 0.0, 1.0, seed=seed + 5)
          output, state = cell(inputs, zero_state)
          # This is legacy behavior to preserve the test.  Weight
          # sharing no longer works by creating a new RNNCell in the
          # same variable scope; so here we store the scope of the
          # first RNNCell for reuse above.
          if rnn_scope is None:
            rnn_scope = (cell._scope, lstm_cell._scope)  # pylint: disable=protected-access
          if state_is_tuple:
            # Flatten the tuple state so both branches compare against the
            # same concatenated golden values.
            state = array_ops.concat(
                [state[0][0], state[0][1], state[1], state[2]], 1)
          sess.run(variables.global_variables_initializer())
          self.assertAllClose(sess.run(output), expected_output)
          self.assertAllClose(sess.run(state), expected_state)
    662 
    663   def testNASCell(self):
    664     num_units = 6
    665     batch_size = 3
    666     expected_output = np.array(
    667         [[0.576751, 0.576751, 0.576751, 0.576751, 0.576751, 0.576751],
    668          [0.618936, 0.618936, 0.618936, 0.618936, 0.618936, 0.618936],
    669          [0.627393, 0.627393, 0.627393, 0.627393, 0.627393, 0.627393]])
    670     expected_state = np.array([[
    671         0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772,
    672         0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087
    673     ], [
    674         0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625,
    675         0.6189357, 0.6189357, 0.61893570, 0.6189357, 0.6189357, 0.6189357
    676     ], [
    677         0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
    678         0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
    679     ]])
    680     with self.test_session() as sess:
    681       with variable_scope.variable_scope(
    682           "nas_test", initializer=init_ops.constant_initializer(0.5)):
    683         cell = contrib_rnn_cell.NASCell(num_units=num_units)
    684         inputs = constant_op.constant(
    685             np.array(
    686                 [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
    687                 dtype=np.float32),
    688             dtype=dtypes.float32)
    689         state_value = constant_op.constant(
    690             0.1 * np.ones((batch_size, num_units), dtype=np.float32),
    691             dtype=dtypes.float32)
    692         init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
    693         output, state = cell(inputs, init_state)
    694         sess.run([variables.global_variables_initializer()])
    695         res = sess.run([output, state])
    696 
    697         # This is a smoke test: Only making sure expected values not change.
    698         self.assertEqual(len(res), 2)
    699         self.assertAllClose(res[0], expected_output)
    700         # There should be 2 states in the tuple.
    701         self.assertEqual(len(res[1]), 2)
    702         # Checking the shape of each state to be batch_size * num_units
    703         new_c, new_h = res[1]
    704         self.assertEqual(new_c.shape[0], batch_size)
    705         self.assertEqual(new_c.shape[1], num_units)
    706         self.assertEqual(new_h.shape[0], batch_size)
    707         self.assertEqual(new_h.shape[1], num_units)
    708         self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
    709 
    710   def testNASCellProj(self):
    711     num_units = 6
    712     batch_size = 3
    713     num_proj = 5
    714     expected_output = np.array(
    715         [[1.697418, 1.697418, 1.697418, 1.697418,
    716           1.697418], [1.840037, 1.840037, 1.840037, 1.840037, 1.840037],
    717          [1.873985, 1.873985, 1.873985, 1.873985, 1.873985]])
    718     expected_state = np.array([[
    719         0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207,
    720         1.69741797, 1.69741797, 1.69741797, 1.69741797, 1.69741797
    721     ], [
    722         0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824,
    723         1.84003687, 1.84003687, 1.84003687, 1.84003687, 1.84003687
    724     ], [
    725         0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
    726         1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
    727     ]])
    728     with self.test_session() as sess:
    729       with variable_scope.variable_scope(
    730           "nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
    731         cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
    732         inputs = constant_op.constant(
    733             np.array(
    734                 [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
    735                 dtype=np.float32),
    736             dtype=dtypes.float32)
    737         state_value_c = constant_op.constant(
    738             0.1 * np.ones((batch_size, num_units), dtype=np.float32),
    739             dtype=dtypes.float32)
    740         state_value_h = constant_op.constant(
    741             0.1 * np.ones((batch_size, num_proj), dtype=np.float32),
    742             dtype=dtypes.float32)
    743         init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
    744         output, state = cell(inputs, init_state)
    745         sess.run([variables.global_variables_initializer()])
    746         res = sess.run([output, state])
    747 
    748         # This is a smoke test: Only making sure expected values not change.
    749         self.assertEqual(len(res), 2)
    750         self.assertAllClose(res[0], expected_output)
    751         # There should be 2 states in the tuple.
    752         self.assertEqual(len(res[1]), 2)
    753         # Checking the shape of each state to be batch_size * num_units
    754         new_c, new_h = res[1]
    755         self.assertEqual(new_c.shape[0], batch_size)
    756         self.assertEqual(new_c.shape[1], num_units)
    757         self.assertEqual(new_h.shape[0], batch_size)
    758         self.assertEqual(new_h.shape[1], num_proj)
    759         self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
    760 
    761   def testUGRNNCell(self):
    762     num_units = 2
    763     batch_size = 3
    764     expected_state_and_output = np.array(
    765         [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
    766          [0.10074195, 0.10074195]],
    767         dtype=np.float32)
    768     with self.test_session() as sess:
    769       with variable_scope.variable_scope(
    770           "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
    771         cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
    772         inputs = constant_op.constant(
    773             np.array(
    774                 [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
    775                 dtype=np.float32),
    776             dtype=dtypes.float32)
    777         init_state = constant_op.constant(
    778             0.1 * np.ones((batch_size, num_units), dtype=np.float32),
    779             dtype=dtypes.float32)
    780         output, state = cell(inputs, init_state)
    781         sess.run([variables.global_variables_initializer()])
    782         res = sess.run([output, state])
    783         # This is a smoke test: Only making sure expected values didn't change.
    784         self.assertEqual(len(res), 2)
    785         self.assertAllClose(res[0], expected_state_and_output)
    786         self.assertAllClose(res[1], expected_state_and_output)
    787 
    788   def testIntersectionRNNCell(self):
    789     num_units = 2
    790     batch_size = 3
    791     expected_state = np.array(
    792         [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
    793          [0.10074195, 0.10074195]],
    794         dtype=np.float32)
    795     expected_output = np.array(
    796         [[2.00431061, 2.00431061], [4.00060606, 4.00060606],
    797          [6.00008249, 6.00008249]],
    798         dtype=np.float32)
    799     with self.test_session() as sess:
    800       with variable_scope.variable_scope(
    801           "intersection_rnn_cell_test",
    802           initializer=init_ops.constant_initializer(0.5)):
    803         cell = contrib_rnn_cell.IntersectionRNNCell(
    804             num_units=num_units, num_in_proj=num_units)
    805         inputs = constant_op.constant(
    806             np.array(
    807                 [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
    808                 dtype=np.float32),
    809             dtype=dtypes.float32)
    810         init_state = constant_op.constant(
    811             0.1 * np.ones((batch_size, num_units), dtype=np.float32),
    812             dtype=dtypes.float32)
    813         output, state = cell(inputs, init_state)
    814         sess.run([variables.global_variables_initializer()])
    815         res = sess.run([output, state])
    816         # This is a smoke test: Only making sure expected values didn't change.
    817         self.assertEqual(len(res), 2)
    818         self.assertAllClose(res[0], expected_output)
    819         self.assertAllClose(res[1], expected_state)
    820 
    821   def testIntersectionRNNCellFailure(self):
    822     num_units = 2
    823     batch_size = 3
    824     cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
    825     inputs = constant_op.constant(
    826         np.array(
    827             [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
    828             dtype=np.float32),
    829         dtype=dtypes.float32)
    830     init_state = constant_op.constant(
    831         0.1 * np.ones((batch_size, num_units), dtype=np.float32),
    832         dtype=dtypes.float32)
    833     with self.assertRaisesRegexp(ValueError,
    834                                  "Must have input size == output size for "
    835                                  "Intersection RNN. To fix, num_in_proj should "
    836                                  "be set to num_units at cell init."):
    837       cell(inputs, init_state)
    838 
    839   def testPhasedLSTMCell(self):
    840     with self.test_session() as sess:
    841       num_units = 2
    842       batch_size = 3
    843       input_size = 4
    844       expected_state_c = np.array(
    845           [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04],
    846            [4.401947e-04, 9.143004e-04]],
    847           dtype=np.float32)
    848       expected_state_h = np.array(
    849           [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04],
    850            [3.347936e-04, 6.953785e-04]],
    851           dtype=np.float32)
    852       with variable_scope.variable_scope(
    853           "root", initializer=init_ops.constant_initializer(0.5)):
    854         t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64)
    855         x = array_ops.zeros([batch_size, input_size])
    856         c0 = array_ops.zeros([batch_size, 2])
    857         h0 = array_ops.zeros([batch_size, 2])
    858         state0 = rnn_cell.LSTMStateTuple(c0, h0)
    859         output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
    860             (t, x), state0)
    861         sess.run([variables.global_variables_initializer()])
    862         res = sess.run(
    863             [output, state], {
    864                 t.name:
    865                     np.array([[1.], [2.], [3.]]),
    866                 x.name:
    867                     np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
    868                               [3., 3., 3., 3.]]),
    869             })
    870         # This is a smoke test, making sure expected values are unchanged.
    871         self.assertEqual(len(res), 2)
    872         self.assertAllClose(res[0], res[1].h)
    873         self.assertAllClose(res[1].c, expected_state_c)
    874         self.assertAllClose(res[1].h, expected_state_h)
    875 
    876   def testConv1DLSTMCell(self):
    877     with self.test_session() as sess:
    878       shape = [2, 1]
    879       filter_size = [3]
    880       num_features = 1
    881       batch_size = 2
    882       expected_state_c = np.array(
    883           [[[1.4375670191], [1.4375670191]], [[2.7542609292], [2.7542609292]]],
    884           dtype=np.float32)
    885       expected_state_h = np.array(
    886           [[[0.6529865603], [0.6529865603]], [[0.8736877431], [0.8736877431]]],
    887           dtype=np.float32)
    888       with variable_scope.variable_scope(
    889           "root", initializer=init_ops.constant_initializer(1.0 / 2.0)):
    890         x = array_ops.placeholder(dtypes.float32, [None, None, 1])
    891         cell = contrib_rnn_cell.Conv1DLSTMCell(
    892             input_shape=shape,
    893             kernel_shape=filter_size,
    894             output_channels=num_features)
    895         hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
    896         output, state = cell(x, hidden)
    897 
    898         sess.run([variables.global_variables_initializer()])
    899         res = sess.run(
    900             [output, state], {
    901                 hidden[0].name: np.array([[[1.], [1.]], [[2.], [2.]]]),
    902                 x.name: np.array([[[1.], [1.]], [[2.], [2.]]]),
    903             })
    904         # This is a smoke test, making sure expected values are unchanged.
    905         self.assertEqual(len(res), 2)
    906         self.assertAllClose(res[0], res[1].h)
    907         self.assertAllClose(res[1].c, expected_state_c)
    908         self.assertAllClose(res[1].h, expected_state_h)
    909 
    910   def testConv2DLSTMCell(self):
    911     with self.test_session() as sess:
    912       shape = [2, 2, 1]
    913       filter_size = [3, 3]
    914       num_features = 1
    915       batch_size = 2
    916       expected_state_c = np.array(
    917           [[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]],
    918            [[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
    919            ]],
    920           dtype=np.float32)
    921       expected_state_h = np.array(
    922           [[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]],
    923            [[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
    924            ]],
    925           dtype=np.float32)
    926       with variable_scope.variable_scope(
    927           "root", initializer=init_ops.constant_initializer(1.0 / 4.0)):
    928         x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
    929         cell = contrib_rnn_cell.Conv2DLSTMCell(
    930             input_shape=shape,
    931             kernel_shape=filter_size,
    932             output_channels=num_features)
    933         hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
    934         output, state = cell(x, hidden)
    935 
    936         sess.run([variables.global_variables_initializer()])
    937         res = sess.run(
    938             [output, state], {
    939                 hidden[0].name:
    940                     np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
    941                                                              [[2.], [2.]]]]),
    942                 x.name:
    943                     np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
    944                                                              [[2.], [2.]]]]),
    945             })
    946         # This is a smoke test, making sure expected values are unchanged.
    947         self.assertEqual(len(res), 2)
    948         self.assertAllClose(res[0], res[1].h)
    949         self.assertAllClose(res[1].c, expected_state_c)
    950         self.assertAllClose(res[1].h, expected_state_h)
    951 
    952   def testConv3DLSTMCell(self):
    953     with self.test_session() as sess:
    954       shape = [2, 2, 2, 1]
    955       filter_size = [3, 3, 3]
    956       num_features = 1
    957       batch_size = 2
    958       expected_state_c = np.array(
    959           [[[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]
    960             ], [[[1.4375670191], [1.4375670191]], [[1.4375670191],
    961                                                    [1.4375670191]]]],
    962            [[[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
    963             ], [[[2.7542609292], [2.7542609292]], [[2.7542609292],
    964                                                    [2.7542609292]]]]],
    965           dtype=np.float32)
    966       expected_state_h = np.array(
    967           [[[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]
    968             ], [[[0.6529865603], [0.6529865603]], [[0.6529865603],
    969                                                    [0.6529865603]]]],
    970            [[[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
    971             ], [[[0.8736877431], [0.8736877431]], [[0.8736877431],
    972                                                    [0.8736877431]]]]],
    973           dtype=np.float32)
    974       with variable_scope.variable_scope(
    975           "root", initializer=init_ops.constant_initializer(1.0 / 8.0)):
    976         x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
    977         cell = contrib_rnn_cell.Conv3DLSTMCell(
    978             input_shape=shape,
    979             kernel_shape=filter_size,
    980             output_channels=num_features)
    981         hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
    982         output, state = cell(x, hidden)
    983 
    984         sess.run([variables.global_variables_initializer()])
    985         res = sess.run(
    986             [output, state], {
    987                 hidden[0].name:
    988                     np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
    989                         1.
    990                     ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]],
    991                                  [[[2.], [2.]], [[2.], [2.]]]]]),
    992                 x.name:
    993                     np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
    994                         1.
    995                     ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]], [[[2.], [2.]],
    996                                                                 [[2.], [2.]]]]])
    997             })
    998         # This is a smoke test, making sure expected values are unchanged.
    999         self.assertEqual(len(res), 2)
   1000         self.assertAllClose(res[0], res[1].h)
   1001         self.assertAllClose(res[1].c, expected_state_c)
   1002         self.assertAllClose(res[1].h, expected_state_h)
   1003 
   1004   def testHighwayWrapper(self):
   1005     with self.test_session() as sess:
   1006       with variable_scope.variable_scope(
   1007           "base_cell", initializer=init_ops.constant_initializer(0.5)):
   1008         x = array_ops.zeros([1, 3])
   1009         m = array_ops.zeros([1, 3])
   1010         base_cell = rnn_cell.GRUCell(3)
   1011         g, m_new = base_cell(x, m)
   1012       with variable_scope.variable_scope(
   1013           "hw_cell", initializer=init_ops.constant_initializer(0.5)):
   1014         hw_cell = contrib_rnn_cell.HighwayWrapper(
   1015             rnn_cell.GRUCell(3), carry_bias_init=-100.0)
   1016         g_res, m_new_res = hw_cell(x, m)
   1017         sess.run([variables.global_variables_initializer()])
   1018       res = sess.run([g, g_res, m_new, m_new_res], {
   1019           x: np.array([[1., 1., 1.]]),
   1020           m: np.array([[0.1, 0.1, 0.1]])
   1021       })
   1022       # As carry_bias_init is very negative, the carry gate is 'open' and the
   1023       # transform gate is 'closed'. This means the output equals the input.
   1024       self.assertAllClose(res[1], res[0])
   1025       # States are left untouched
   1026       self.assertAllClose(res[2], res[3])
   1027 
   1028   def testGLSTMCell(self):
   1029     # Ensure that G-LSTM matches LSTM when number_of_groups = 1
   1030     batch_size = 2
   1031     num_units = 4
   1032     number_of_groups = 1
   1033 
   1034     with self.test_session() as sess:
   1035       with variable_scope.variable_scope(
   1036           "root1", initializer=init_ops.constant_initializer(0.5)):
   1037         x = array_ops.ones([batch_size, num_units])
   1038         # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
   1039         gcell = contrib_rnn_cell.GLSTMCell(
   1040             num_units=num_units, number_of_groups=number_of_groups)
   1041         cell = rnn_cell.LSTMCell(num_units=num_units)
   1042         self.assertTrue(isinstance(gcell.state_size, tuple))
   1043         zero_state = gcell.zero_state(
   1044             batch_size=batch_size, dtype=dtypes.float32)
   1045         gh, gs = gcell(x, zero_state)
   1046         h, g = cell(x, zero_state)
   1047 
   1048         sess.run([variables.global_variables_initializer()])
   1049         glstm_result = sess.run([gh, gs])
   1050         lstm_result = sess.run([h, g])
   1051 
   1052         self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
   1053         self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)
   1054 
   1055     # Test that G-LSTM subgroup act like corresponding sub-LSTMs
   1056     batch_size = 2
   1057     num_units = 4
   1058     number_of_groups = 2
   1059 
   1060     with self.test_session() as sess:
   1061       with variable_scope.variable_scope(
   1062           "root2", initializer=init_ops.constant_initializer(0.5)):
   1063         # input for G-LSTM with 2 groups
   1064         glstm_input = array_ops.ones([batch_size, num_units])
   1065         gcell = contrib_rnn_cell.GLSTMCell(
   1066             num_units=num_units, number_of_groups=number_of_groups)
   1067         gcell_zero_state = gcell.zero_state(
   1068             batch_size=batch_size, dtype=dtypes.float32)
   1069         gh, gs = gcell(glstm_input, gcell_zero_state)
   1070 
   1071         # input for LSTM cell simulating single G-LSTM group
   1072         lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
   1073         # note division by number_of_groups. This cell one simulates G-LSTM group
   1074         cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
   1075         cell_zero_state = cell.zero_state(
   1076             batch_size=batch_size, dtype=dtypes.float32)
   1077         h, g = cell(lstm_input, cell_zero_state)
   1078 
   1079         sess.run([variables.global_variables_initializer()])
   1080         [gh_res, h_res] = sess.run([gh, h])
   1081         self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)],
   1082                             h_res, 1e-5)
   1083         self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
   1084                             h_res, 1e-5)
   1085 
   1086 
   1087 class LayerNormBasicLSTMCellTest(test.TestCase):
   1088 
   1089   # NOTE: all the values in the current test case have been calculated.
   1090 
   1091   def testBasicLSTMCell(self):
   1092     with self.test_session() as sess:
   1093       with variable_scope.variable_scope(
   1094           "root", initializer=init_ops.constant_initializer(0.5)):
   1095         x = array_ops.zeros([1, 2])
   1096         c0 = array_ops.zeros([1, 2])
   1097         h0 = array_ops.zeros([1, 2])
   1098         state0 = rnn_cell.LSTMStateTuple(c0, h0)
   1099         c1 = array_ops.zeros([1, 2])
   1100         h1 = array_ops.zeros([1, 2])
   1101         state1 = rnn_cell.LSTMStateTuple(c1, h1)
   1102         state = (state0, state1)
   1103         single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2)
   1104         cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
   1105         g, out_m = cell(x, state)
   1106         sess.run([variables.global_variables_initializer()])
   1107         res = sess.run(
   1108             [g, out_m], {
   1109                 x.name: np.array([[1., 1.]]),
   1110                 c0.name: 0.1 * np.asarray([[0, 1]]),
   1111                 h0.name: 0.1 * np.asarray([[2, 3]]),
   1112                 c1.name: 0.1 * np.asarray([[4, 5]]),
   1113                 h1.name: 0.1 * np.asarray([[6, 7]]),
   1114             })
   1115 
   1116         expected_h = np.array([[-0.38079708, 0.38079708]])
   1117         expected_state0_c = np.array([[-1.0, 1.0]])
   1118         expected_state0_h = np.array([[-0.38079708, 0.38079708]])
   1119         expected_state1_c = np.array([[-1.0, 1.0]])
   1120         expected_state1_h = np.array([[-0.38079708, 0.38079708]])
   1121 
   1122         actual_h = res[0]
   1123         actual_state0_c = res[1][0].c
   1124         actual_state0_h = res[1][0].h
   1125         actual_state1_c = res[1][1].c
   1126         actual_state1_h = res[1][1].h
   1127 
   1128         self.assertAllClose(actual_h, expected_h, 1e-5)
   1129         self.assertAllClose(expected_state0_c, actual_state0_c, 1e-5)
   1130         self.assertAllClose(expected_state0_h, actual_state0_h, 1e-5)
   1131         self.assertAllClose(expected_state1_c, actual_state1_c, 1e-5)
   1132         self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5)
   1133 
   1134       with variable_scope.variable_scope(
   1135           "other", initializer=init_ops.constant_initializer(0.5)):
   1136         x = array_ops.zeros(
   1137             [1, 3])  # Test BasicLSTMCell with input_size != num_units.
   1138         c = array_ops.zeros([1, 2])
   1139         h = array_ops.zeros([1, 2])
   1140         state = rnn_cell.LSTMStateTuple(c, h)
   1141         cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
   1142         g, out_m = cell(x, state)
   1143         sess.run([variables.global_variables_initializer()])
   1144         res = sess.run(
   1145             [g, out_m], {
   1146                 x.name: np.array([[1., 1., 1.]]),
   1147                 c.name: 0.1 * np.asarray([[0, 1]]),
   1148                 h.name: 0.1 * np.asarray([[2, 3]]),
   1149             })
   1150 
   1151         expected_h = np.array([[-0.38079708, 0.38079708]])
   1152         expected_c = np.array([[-1.0, 1.0]])
   1153         self.assertEqual(len(res), 2)
   1154         self.assertAllClose(res[0], expected_h, 1e-5)
   1155         self.assertAllClose(res[1].c, expected_c, 1e-5)
   1156         self.assertAllClose(res[1].h, expected_h, 1e-5)
   1157 
   1158   def testBasicLSTMCellWithoutNorm(self):
   1159     """Tests that BasicLSTMCell with layer_norm=False."""
   1160     with self.test_session() as sess:
   1161       with variable_scope.variable_scope(
   1162           "root", initializer=init_ops.constant_initializer(0.5)):
   1163         x = array_ops.zeros([1, 2])
   1164         c0 = array_ops.zeros([1, 2])
   1165         h0 = array_ops.zeros([1, 2])
   1166         state0 = rnn_cell.LSTMStateTuple(c0, h0)
   1167         c1 = array_ops.zeros([1, 2])
   1168         h1 = array_ops.zeros([1, 2])
   1169         state1 = rnn_cell.LSTMStateTuple(c1, h1)
   1170         state = (state0, state1)
   1171         single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
   1172         cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
   1173         g, out_m = cell(x, state)
   1174         sess.run([variables.global_variables_initializer()])
   1175         res = sess.run(
   1176             [g, out_m], {
   1177                 x.name: np.array([[1., 1.]]),
   1178                 c0.name: 0.1 * np.asarray([[0, 1]]),
   1179                 h0.name: 0.1 * np.asarray([[2, 3]]),
   1180                 c1.name: 0.1 * np.asarray([[4, 5]]),
   1181                 h1.name: 0.1 * np.asarray([[6, 7]]),
   1182             })
   1183 
   1184         expected_h = np.array([[0.70230919, 0.72581059]])
   1185         expected_state0_c = np.array([[0.8020075, 0.89599884]])
   1186         expected_state0_h = np.array([[0.56668288, 0.60858738]])
   1187         expected_state1_c = np.array([[1.17500675, 1.26892781]])
   1188         expected_state1_h = np.array([[0.70230919, 0.72581059]])
   1189 
   1190         actual_h = res[0]
   1191         actual_state0_c = res[1][0].c
   1192         actual_state0_h = res[1][0].h
   1193         actual_state1_c = res[1][1].c
   1194         actual_state1_h = res[1][1].h
   1195 
   1196         self.assertAllClose(actual_h, expected_h, 1e-5)
   1197         self.assertAllClose(expected_state0_c, actual_state0_c, 1e-5)
   1198         self.assertAllClose(expected_state0_h, actual_state0_h, 1e-5)
   1199         self.assertAllClose(expected_state1_c, actual_state1_c, 1e-5)
   1200         self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5)
   1201 
   1202       with variable_scope.variable_scope(
   1203           "other", initializer=init_ops.constant_initializer(0.5)) as vs:
   1204         x = array_ops.zeros(
   1205             [1, 3])  # Test BasicLSTMCell with input_size != num_units.
   1206         c = array_ops.zeros([1, 2])
   1207         h = array_ops.zeros([1, 2])
   1208         state = rnn_cell.LSTMStateTuple(c, h)
   1209         cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
   1210         g, out_m = cell(x, state)
   1211         sess.run([variables.global_variables_initializer()])
   1212         res = sess.run(
   1213             [g, out_m], {
   1214                 x.name: np.array([[1., 1., 1.]]),
   1215                 c.name: 0.1 * np.asarray([[0, 1]]),
   1216                 h.name: 0.1 * np.asarray([[2, 3]]),
   1217             })
   1218 
   1219         expected_h = np.array([[0.64121795, 0.68166804]])
   1220         expected_c = np.array([[0.88477188, 0.98103917]])
   1221         self.assertEqual(len(res), 2)
   1222         self.assertAllClose(res[0], expected_h, 1e-5)
   1223         self.assertAllClose(res[1].c, expected_c, 1e-5)
   1224         self.assertAllClose(res[1].h, expected_h, 1e-5)
   1225 
   1226   def testBasicLSTMCellWithStateTuple(self):
   1227     with self.test_session() as sess:
   1228       with variable_scope.variable_scope(
   1229           "root", initializer=init_ops.constant_initializer(0.5)):
   1230         x = array_ops.zeros([1, 2])
   1231         c0 = array_ops.zeros([1, 2])
   1232         h0 = array_ops.zeros([1, 2])
   1233         state0 = rnn_cell.LSTMStateTuple(c0, h0)
   1234         c1 = array_ops.zeros([1, 2])
   1235         h1 = array_ops.zeros([1, 2])
   1236         state1 = rnn_cell.LSTMStateTuple(c1, h1)
   1237         cell = rnn_cell.MultiRNNCell(
   1238             [contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
   1239         h, (s0, s1) = cell(x, (state0, state1))
   1240         sess.run([variables.global_variables_initializer()])
   1241         res = sess.run(
   1242             [h, s0, s1], {
   1243                 x.name: np.array([[1., 1.]]),
   1244                 c0.name: 0.1 * np.asarray([[0, 1]]),
   1245                 h0.name: 0.1 * np.asarray([[2, 3]]),
   1246                 c1.name: 0.1 * np.asarray([[4, 5]]),
   1247                 h1.name: 0.1 * np.asarray([[6, 7]]),
   1248             })
   1249 
   1250         expected_h = np.array([[-0.38079708, 0.38079708]])
   1251         expected_h0 = np.array([[-0.38079708, 0.38079708]])
   1252         expected_c0 = np.array([[-1.0, 1.0]])
   1253         expected_h1 = np.array([[-0.38079708, 0.38079708]])
   1254         expected_c1 = np.array([[-1.0, 1.0]])
   1255 
   1256         self.assertEqual(len(res), 3)
   1257         self.assertAllClose(res[0], expected_h, 1e-5)
   1258         self.assertAllClose(res[1].c, expected_c0, 1e-5)
   1259         self.assertAllClose(res[1].h, expected_h0, 1e-5)
   1260         self.assertAllClose(res[2].c, expected_c1, 1e-5)
   1261         self.assertAllClose(res[2].h, expected_h1, 1e-5)
   1262 
   1263   def testBasicLSTMCellWithStateTupleLayerNorm(self):
   1264     """The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
   1265     with self.test_session() as sess:
   1266       with variable_scope.variable_scope(
   1267           "root", initializer=init_ops.constant_initializer(0.5)):
   1268         x = array_ops.zeros([1, 2])
   1269         c0 = array_ops.zeros([1, 2])
   1270         h0 = array_ops.zeros([1, 2])
   1271         state0 = rnn_cell_impl.LSTMStateTuple(c0, h0)
   1272         c1 = array_ops.zeros([1, 2])
   1273         h1 = array_ops.zeros([1, 2])
   1274         state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
   1275         cell = rnn_cell_impl.MultiRNNCell([
   1276             contrib_rnn_cell.LayerNormLSTMCell(
   1277                 2, layer_norm=True, norm_gain=1.0, norm_shift=0.0)
   1278             for _ in range(2)
   1279         ])
   1280         h, (s0, s1) = cell(x, (state0, state1))
   1281         sess.run([variables.global_variables_initializer()])
   1282         res = sess.run(
   1283             [h, s0, s1], {
   1284                 x.name: np.array([[1., 1.]]),
   1285                 c0.name: 0.1 * np.asarray([[0, 1]]),
   1286                 h0.name: 0.1 * np.asarray([[2, 3]]),
   1287                 c1.name: 0.1 * np.asarray([[4, 5]]),
   1288                 h1.name: 0.1 * np.asarray([[6, 7]]),
   1289             })
   1290 
   1291         expected_h = np.array([[-0.38079708, 0.38079708]])
   1292         expected_h0 = np.array([[-0.38079708, 0.38079708]])
   1293         expected_c0 = np.array([[-1.0, 1.0]])
   1294         expected_h1 = np.array([[-0.38079708, 0.38079708]])
   1295         expected_c1 = np.array([[-1.0, 1.0]])
   1296 
   1297         self.assertEqual(len(res), 3)
   1298         self.assertAllClose(res[0], expected_h, 1e-5)
   1299         self.assertAllClose(res[1].c, expected_c0, 1e-5)
   1300         self.assertAllClose(res[1].h, expected_h0, 1e-5)
   1301         self.assertAllClose(res[2].c, expected_c1, 1e-5)
   1302         self.assertAllClose(res[2].h, expected_h1, 1e-5)
   1303 
   1304   def testBasicLSTMCellWithDropout(self):
   1305 
   1306     def _is_close(x, y, digits=4):
   1307       delta = x - y
   1308       return delta < 10**(-digits)
   1309 
   1310     def _is_close_in(x, items, digits=4):
   1311       for i in items:
   1312         if _is_close(x, i, digits):
   1313           return True
   1314       return False
   1315 
   1316     keep_prob = 0.5
   1317     c_high = 2.9998924946
   1318     c_low = 0.999983298578
   1319     h_low = 0.761552567265
   1320     h_high = 0.995008519604
   1321     num_units = 5
   1322     allowed_low = [1, 2, 3]
   1323 
   1324     with self.test_session() as sess:
   1325       with variable_scope.variable_scope(
   1326           "other", initializer=init_ops.constant_initializer(1)):
   1327         x = array_ops.zeros([1, 5])
   1328         c = array_ops.zeros([1, 5])
   1329         h = array_ops.zeros([1, 5])
   1330         state = rnn_cell.LSTMStateTuple(c, h)
   1331         cell = contrib_rnn_cell.LayerNormBasicLSTMCell(
   1332             num_units, layer_norm=False, dropout_keep_prob=keep_prob)
   1333 
   1334         g, s = cell(x, state)
   1335         sess.run([variables.global_variables_initializer()])
   1336         res = sess.run(
   1337             [g, s], {
   1338                 x.name: np.ones([1, 5]),
   1339                 c.name: np.ones([1, 5]),
   1340                 h.name: np.ones([1, 5]),
   1341             })
   1342 
   1343         # Since the returned tensors are of size [1,n]
   1344         # get the first component right now.
   1345         actual_h = res[0][0]
   1346         actual_state_c = res[1].c[0]
   1347         actual_state_h = res[1].h[0]
   1348 
   1349         # For each item in `c` (the cell inner state) check that
   1350         # it is equal to one of the allowed values `c_high` (not
   1351         # dropped out) or `c_low` (dropped out) and verify that the
   1352         # corresponding item in `h` (the cell activation) is coherent.
   1353         # Count the dropped activations and check that their number is
   1354         # coherent with the dropout probability.
   1355         dropped_count = 0
   1356         self.assertTrue((actual_h == actual_state_h).all())
   1357         for citem, hitem in zip(actual_state_c, actual_state_h):
   1358           self.assertTrue(_is_close_in(citem, [c_low, c_high]))
   1359           if _is_close(citem, c_low):
   1360             self.assertTrue(_is_close(hitem, h_low))
   1361             dropped_count += 1
   1362           elif _is_close(citem, c_high):
   1363             self.assertTrue(_is_close(hitem, h_high))
   1364         self.assertIn(dropped_count, allowed_low)
   1365 
   1366 
   1367 def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth, num_layers,
   1368                                 max_time, compiled):
   1369   with variable_scope.variable_scope(
   1370       "root",
   1371       initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
   1372     inputs = variable_scope.get_variable(
   1373         "inputs",
   1374         initializer=random_ops.random_uniform(
   1375             (max_time, batch_size, input_depth), seed=1))
   1376     maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
   1377     cell = rnn_cell.MultiRNNCell(
   1378         [maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
   1379     initial_state = cell.zero_state(batch_size=batch_size, dtype=dtypes.float32)
   1380     outputs, final_state = rnn.dynamic_rnn(
   1381         cell=cell, inputs=inputs, initial_state=initial_state, time_major=True)
   1382     flat_final_state = nest.flatten(final_state)
   1383     trainable_variables = variables.trainable_variables()
   1384     outputs_grad = gradients_impl.gradients(
   1385         [outputs], trainable_variables + [inputs] + nest.flatten(initial_state))
   1386     final_state_grad = gradients_impl.gradients(
   1387         flat_final_state,
   1388         trainable_variables + [inputs] + nest.flatten(initial_state))
   1389 
   1390     return {
   1391         "outputs": outputs,
   1392         "final_state": flat_final_state,
   1393         "outputs_grad": outputs_grad,
   1394         "final_state_grad": final_state_grad
   1395     }
   1396 
   1397 
   1398 class CompiledWrapperTest(test.TestCase):
   1399 
   1400   def testMultiRNNCellWithLSTMCellAndXLA(self):
   1401     # TODO(b/34735319): Don't run this test if XLA is not available.
   1402     batch_size = 16
   1403     num_units = 32
   1404     input_depth = 12
   1405     num_layers = 2
   1406     max_time = 20
   1407 
   1408     atol = 1e-5
   1409 
   1410     random_seed.set_random_seed(1234)
   1411     with self.test_session(graph=ops.Graph()) as sess:
   1412       xla_ops = _create_multi_lstm_cell_ops(
   1413           batch_size=batch_size,
   1414           num_units=num_units,
   1415           input_depth=input_depth,
   1416           num_layers=num_layers,
   1417           max_time=max_time,
   1418           compiled=True)
   1419       sess.run([variables.global_variables_initializer()])
   1420       xla_results = sess.run(xla_ops)
   1421 
   1422     random_seed.set_random_seed(1234)
   1423     with self.test_session(graph=ops.Graph()) as sess:
   1424       non_xla_ops = _create_multi_lstm_cell_ops(
   1425           batch_size=batch_size,
   1426           num_units=num_units,
   1427           input_depth=input_depth,
   1428           num_layers=num_layers,
   1429           max_time=max_time,
   1430           compiled=False)
   1431       sess.run([variables.global_variables_initializer()])
   1432       non_xla_results = sess.run(non_xla_ops)
   1433 
   1434     self.assertAllClose(
   1435         non_xla_results["outputs"], xla_results["outputs"], atol=atol)
   1436 
   1437     for xla_value, non_xla_value in zip(xla_results["final_state"],
   1438                                         non_xla_results["final_state"]):
   1439       self.assertAllClose(xla_value, non_xla_value, atol=atol)
   1440 
   1441     for xla_g, non_xla_g in zip(xla_results["outputs_grad"],
   1442                                 non_xla_results["outputs_grad"]):
   1443       self.assertAllClose(xla_g, non_xla_g, atol=atol)
   1444 
   1445     for xla_g, non_xla_g in zip(xla_results["final_state_grad"],
   1446                                 non_xla_results["final_state_grad"]):
   1447       self.assertAllClose(xla_g, non_xla_g, atol=atol)
   1448 
   1449   def testMultiRNNCellWithStateTuple(self):
   1450     with self.test_session() as sess:
   1451       with variable_scope.variable_scope(
   1452           "root", initializer=init_ops.constant_initializer(0.5)):
   1453         x = array_ops.zeros([1, 2])
   1454         m_bad = array_ops.zeros([1, 4])
   1455         m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))
   1456 
   1457         # Test incorrectness of state
   1458         with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
   1459           rnn_cell.MultiRNNCell(
   1460               [rnn_cell.GRUCell(2) for _ in range(2)],
   1461               state_is_tuple=True)(x, m_bad)
   1462 
   1463         _, ml = rnn_cell.MultiRNNCell(
   1464             [rnn_cell.GRUCell(2) for _ in range(2)],
   1465             state_is_tuple=True)(x, m_good)
   1466 
   1467         sess.run([variables.global_variables_initializer()])
   1468         res = sess.run(
   1469             ml, {
   1470                 x.name: np.array([[1., 1.]]),
   1471                 m_good[0].name: np.array([[0.1, 0.1]]),
   1472                 m_good[1].name: np.array([[0.1, 0.1]])
   1473             })
   1474 
   1475         # The numbers in results were not calculated, this is just a
   1476         # smoke test.  However, these numbers should match those of
   1477         # the test testMultiRNNCell.
   1478         self.assertAllClose(res[0], [[0.175991, 0.175991]])
   1479         self.assertAllClose(res[1], [[0.13248, 0.13248]])
   1480 
   1481 
class BenchmarkLSTMCellXLA(test.Benchmark):
  """Benchmarks dynamic_rnn over a stacked LSTM, with and without XLA."""

  def benchmarkDynamicRNNWithMultiLSTMCell(self):
    """Sweeps thread settings, device, sizes and compilation, printing a
    tab-separated wall-time row for each configuration."""
    num_layers = 3
    max_time = 50
    print("benchmarkDynamicRNNWithMultiLSTMCell")
    print("\t" + "\t".join([
        "inter_th", "intra_th", "batch_size", "num_units", "input_depth",
        "device", "compiled", "wall_time"
    ]))

    # The very first benchmarked graph also pays one-time startup costs, so
    # an extra warmup run is recorded once under the name "ignore_warmup".
    warmup_run = True
    for (threads, device, num_units, batch_size, input_depth,
         compiled) in itertools.product([{
             "inter": 0,
             "intra": 0
         }, {
             "inter": 1,
             "intra": 4
         }], ["cpu", "gpu"], [32, 512], [1, 32, 256], [32, 512], [False, True]):
      if threads["inter"] != 0:
        # We only care about testing inter/intra op limitations on
        # CPU with small batch size, to mimic embedded devices.
        if device != "cpu" or batch_size != 1:
          continue
      if device == "cpu" and batch_size > 32:
        continue
      random_seed.set_random_seed(1234)
      config = config_pb2.ConfigProto(
          inter_op_parallelism_threads=threads["inter"],
          intra_op_parallelism_threads=threads["intra"],
          allow_soft_placement=False)
      with session.Session(config=config, graph=ops.Graph()) as sess:
        with ops.device("/%s:0" % device):
          ops_dict = _create_multi_lstm_cell_ops(
              batch_size=batch_size,
              num_units=num_units,
              input_depth=input_depth,
              num_layers=num_layers,
              max_time=max_time,
              compiled=compiled)
        sess.run([variables.global_variables_initializer()])
        # Group every returned op so one run executes the whole graph.
        all_ops = nest.flatten(ops_dict.values())
        all_ops_group = control_flow_ops.group(*all_ops)
        name_suffix = ("inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
                       "_device_%s_xla_%s" %
                       (threads["inter"], threads["intra"], batch_size,
                        num_units, input_depth, device, compiled))
        if warmup_run:
          self.run_op_benchmark(
              sess, all_ops_group, min_iters=30, name="ignore_warmup")
          warmup_run = False
        benchmark_results = self.run_op_benchmark(
            sess,
            all_ops_group,
            min_iters=50,
            name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
        print("\t" + "\t".join([
            "%s" % x
            for x in [
                threads["inter"], threads["intra"], batch_size, num_units,
                input_depth, device, compiled, benchmark_results["wall_time"]
            ]
        ]))
   1546 
   1547 
   1548 class WeightNormLSTMCellTest(test.TestCase):
   1549   """Compared cell output with pre-calculated values."""
   1550 
   1551   def _cell_output(self, cell):
   1552     """Calculate cell output"""
   1553 
   1554     with self.test_session() as sess:
   1555       init = init_ops.constant_initializer(0.5)
   1556       with variable_scope.variable_scope("root",
   1557                                          initializer=init):
   1558         x = array_ops.zeros([1, 2])
   1559         c0 = array_ops.zeros([1, 2])
   1560         h0 = array_ops.zeros([1, 2])
   1561 
   1562         state0 = rnn_cell.LSTMStateTuple(c0, h0)
   1563 
   1564         xout, sout = cell()(x, state0)
   1565 
   1566       sess.run([variables.global_variables_initializer()])
   1567       res = sess.run([xout, sout], {
   1568           x.name: np.array([[1., 1.]]),
   1569           c0.name: 0.1 * np.asarray([[0, 1]]),
   1570           h0.name: 0.1 * np.asarray([[2, 3]]),
   1571       })
   1572 
   1573     actual_state_c = res[1].c
   1574     actual_state_h = res[1].h
   1575 
   1576     return actual_state_c, actual_state_h
   1577 
   1578   def testBasicCell(self):
   1579     """Tests cell w/o peepholes and w/o normalisation"""
   1580 
   1581     def cell():
   1582       return contrib_rnn_cell.WeightNormLSTMCell(2,
   1583                                                  norm=False,
   1584                                                  use_peepholes=False)
   1585 
   1586     actual_c, actual_h = self._cell_output(cell)
   1587 
   1588     expected_c = np.array([[0.65937078, 0.74983585]])
   1589     expected_h = np.array([[0.44923624, 0.49362513]])
   1590 
   1591     self.assertAllClose(expected_c, actual_c, 1e-5)
   1592     self.assertAllClose(expected_h, actual_h, 1e-5)
   1593 
   1594   def testNonbasicCell(self):
   1595     """Tests cell with peepholes and w/o normalisation"""
   1596 
   1597     def cell():
   1598       return contrib_rnn_cell.WeightNormLSTMCell(2,
   1599                                                  norm=False,
   1600                                                  use_peepholes=True)
   1601 
   1602     actual_c, actual_h = self._cell_output(cell)
   1603 
   1604     expected_c = np.array([[0.65937084, 0.7574988]])
   1605     expected_h = np.array([[0.4792085, 0.53470564]])
   1606 
   1607     self.assertAllClose(expected_c, actual_c, 1e-5)
   1608     self.assertAllClose(expected_h, actual_h, 1e-5)
   1609 
   1610 
   1611   def testBasicCellWithNorm(self):
   1612     """Tests cell w/o peepholes and with normalisation"""
   1613 
   1614     def cell():
   1615       return contrib_rnn_cell.WeightNormLSTMCell(2,
   1616                                                  norm=True,
   1617                                                  use_peepholes=False)
   1618 
   1619     actual_c, actual_h = self._cell_output(cell)
   1620 
   1621     expected_c = np.array([[0.50125383, 0.58805949]])
   1622     expected_h = np.array([[0.32770363, 0.37397948]])
   1623 
   1624     self.assertAllClose(expected_c, actual_c, 1e-5)
   1625     self.assertAllClose(expected_h, actual_h, 1e-5)
   1626 
   1627   def testNonBasicCellWithNorm(self):
   1628     """Tests cell with peepholes and with normalisation"""
   1629 
   1630     def cell():
   1631       return contrib_rnn_cell.WeightNormLSTMCell(2,
   1632                                                  norm=True,
   1633                                                  use_peepholes=True)
   1634 
   1635     actual_c, actual_h = self._cell_output(cell)
   1636 
   1637     expected_c = np.array([[0.50125383, 0.59587258]])
   1638     expected_h = np.array([[0.35041603, 0.40873795]])
   1639 
   1640     self.assertAllClose(expected_c, actual_c, 1e-5)
   1641     self.assertAllClose(expected_h, actual_h, 1e-5)
   1642 
# Run the full test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()
   1645