# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.contrib.rnn.python.ops.fused_rnn_cell."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class FusedRnnCellTest(test.TestCase):

  def testBasicRNNFusedWrapper(self):
    """Tests that FusedRNNCellAdaptor around BasicRNNCell matches static_rnn.

    Outputs, final state, input gradients, and weight gradients of the
    adaptor (in both static and dynamic modes) are compared against an
    unfused reference run.
    """

    with self.test_session() as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)
      cell = rnn_cell.BasicRNNCell(10)
      batch_size = 5
      input_size = 20
      timelen = 15
      inputs = constant_op.constant(
          np.random.randn(timelen, batch_size, input_size))

      # Reference: an unfused static_rnn over the unstacked inputs.
      with variable_scope.variable_scope("basic", initializer=initializer):
        unpacked_inputs = array_ops.unstack(inputs)
        outputs, state = rnn.static_rnn(
            cell, unpacked_inputs, dtype=dtypes.float64)
        packed_outputs = array_ops.stack(outputs)
        basic_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("basic/")
        ]
        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_state = sess.run([packed_outputs, state])
        basic_grads = sess.run(gradients_impl.gradients(packed_outputs, inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(packed_outputs, basic_vars))

      # The fused adaptor consumes the whole time-major batch in one call.
      with variable_scope.variable_scope(
          "fused_static", initializer=initializer):
        fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
            rnn_cell.BasicRNNCell(10))
        outputs, state = fused_cell(inputs, dtype=dtypes.float64)
        fused_static_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("fused_static/")
        ]
        sess.run([variables.global_variables_initializer()])
        fused_static_outputs, fused_static_state = sess.run([outputs, state])
        fused_static_grads = sess.run(gradients_impl.gradients(outputs, inputs))
        fused_static_wgrads = sess.run(
            gradients_impl.gradients(outputs, fused_static_vars))

      self.assertAllClose(basic_outputs, fused_static_outputs)
      self.assertAllClose(basic_state, fused_static_state)
      self.assertAllClose(basic_grads, fused_static_grads)
      for basic, fused in zip(basic_wgrads, fused_static_wgrads):
        self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

      # Same adaptor, but backed by dynamic_rnn instead of static_rnn.
      with variable_scope.variable_scope(
          "fused_dynamic", initializer=initializer):
        fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
            rnn_cell.BasicRNNCell(10), use_dynamic_rnn=True)
        outputs, state = fused_cell(inputs, dtype=dtypes.float64)
        fused_dynamic_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("fused_dynamic/")
        ]
        sess.run([variables.global_variables_initializer()])
        fused_dynamic_outputs, fused_dynamic_state = sess.run([outputs, state])
        fused_dynamic_grads = sess.run(
            gradients_impl.gradients(outputs, inputs))
        fused_dynamic_wgrads = sess.run(
            gradients_impl.gradients(outputs, fused_dynamic_vars))

      self.assertAllClose(basic_outputs, fused_dynamic_outputs)
      self.assertAllClose(basic_state, fused_dynamic_state)
      self.assertAllClose(basic_grads, fused_dynamic_grads)
      for basic, fused in zip(basic_wgrads, fused_dynamic_wgrads):
        self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

  def testTimeReversedFusedRNN(self):
    """Tests that TimeReversedFusedRNN matches static_bidirectional_rnn."""
    with self.test_session() as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890213)
      fw_cell = rnn_cell.BasicRNNCell(10)
      bw_cell = rnn_cell.BasicRNNCell(10)
      batch_size = 5
      input_size = 20
      timelen = 15
      inputs = constant_op.constant(
          np.random.randn(timelen, batch_size, input_size))

      # Reference: an unfused bidirectional RNN.
      with variable_scope.variable_scope("basic", initializer=initializer):
        unpacked_inputs = array_ops.unstack(inputs)
        outputs, fw_state, bw_state = rnn.static_bidirectional_rnn(
            fw_cell, bw_cell, unpacked_inputs, dtype=dtypes.float64)
        packed_outputs = array_ops.stack(outputs)
        basic_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("basic/")
        ]
        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_fw_state, basic_bw_state = sess.run(
            [packed_outputs, fw_state, bw_state])
        basic_grads = sess.run(gradients_impl.gradients(packed_outputs, inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(packed_outputs, basic_vars))

      # A fused forward cell plus a time-reversed fused cell replicates the
      # bidirectional RNN: concatenating their outputs along the feature axis
      # should match the unfused reference.
      with variable_scope.variable_scope("fused", initializer=initializer):
        fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
            rnn_cell.BasicRNNCell(10))
        fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(
            fused_rnn_cell.FusedRNNCellAdaptor(rnn_cell.BasicRNNCell(10)))
        fw_outputs, fw_state = fused_cell(
            inputs, dtype=dtypes.float64, scope="fw")
        bw_outputs, bw_state = fused_bw_cell(
            inputs, dtype=dtypes.float64, scope="bw")
        outputs = array_ops.concat([fw_outputs, bw_outputs], 2)
        fused_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("fused/")
        ]
        sess.run([variables.global_variables_initializer()])
        fused_outputs, fused_fw_state, fused_bw_state = sess.run(
            [outputs, fw_state, bw_state])
        fused_grads = sess.run(gradients_impl.gradients(outputs, inputs))
        fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars))

      self.assertAllClose(basic_outputs, fused_outputs)
      self.assertAllClose(basic_fw_state, fused_fw_state)
      self.assertAllClose(basic_bw_state, fused_bw_state)
      self.assertAllClose(basic_grads, fused_grads)
      for basic, fused in zip(basic_wgrads, fused_wgrads):
        self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
  test.main()