# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.contrib.rnn.python.ops.fused_rnn_cell."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.rnn.python.ops import fused_rnn_cell
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class FusedRnnCellTest(test.TestCase):

  def testBasicRNNFusedWrapper(self):
    """Tests that wrapping BasicRNNCell in FusedRNNCellAdaptor preserves behavior."""

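    # The same computation is built three ways: unfused static_rnn (the
    # baseline), FusedRNNCellAdaptor over static_rnn, and FusedRNNCellAdaptor
    # over dynamic_rnn. Outputs, final states, and gradients must all agree.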
    with self.test_session() as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)
      cell = rnn_cell.BasicRNNCell(10)
      batch_size = 5
      input_size = 20
      timelen = 15
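      # Inputs are time-major: [timelen, batch_size, input_size], the layout
      # expected by static_rnn (after unstacking) and by the fused cells.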
      inputs = constant_op.constant(
          np.random.randn(timelen, batch_size, input_size))
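      # Baseline: run the plain (unfused) cell one time step at a time with
      # static_rnn, then stack the per-step outputs back into one tensor.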
      with variable_scope.variable_scope("basic", initializer=initializer):
        unpacked_inputs = array_ops.unstack(inputs)
        outputs, state = rnn.static_rnn(
            cell, unpacked_inputs, dtype=dtypes.float64)
        packed_outputs = array_ops.stack(outputs)
        basic_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("basic/")
        ]
        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_state = sess.run([packed_outputs, state])
        basic_grads = sess.run(gradients_impl.gradients(packed_outputs, inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(packed_outputs, basic_vars))

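      # FusedRNNCellAdaptor exposes an RNNCell through the FusedRNNCell
      # interface: a single call consumes the whole time-major input tensor
      # (by default it still iterates internally via static_rnn).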
      with variable_scope.variable_scope(
          "fused_static", initializer=initializer):
        fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
            rnn_cell.BasicRNNCell(10))
        outputs, state = fused_cell(inputs, dtype=dtypes.float64)
        fused_static_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("fused_static/")
        ]
        sess.run([variables.global_variables_initializer()])
        fused_static_outputs, fused_static_state = sess.run([outputs, state])
        fused_static_grads = sess.run(gradients_impl.gradients(outputs, inputs))
        fused_static_wgrads = sess.run(
            gradients_impl.gradients(outputs, fused_static_vars))

      self.assertAllClose(basic_outputs, fused_static_outputs)
      self.assertAllClose(basic_state, fused_static_state)
      self.assertAllClose(basic_grads, fused_static_grads)
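      # Weight gradients are sums over all time steps, so small per-step
      # numerical differences accumulate; compare with a looser tolerance.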
      for basic, fused in zip(basic_wgrads, fused_static_wgrads):
        self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

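      # Same check again, but with the adaptor running dynamic_rnn
      # (use_dynamic_rnn=True keeps the time-major layout).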
      with variable_scope.variable_scope(
          "fused_dynamic", initializer=initializer):
        fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
            rnn_cell.BasicRNNCell(10), use_dynamic_rnn=True)
        outputs, state = fused_cell(inputs, dtype=dtypes.float64)
        fused_dynamic_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("fused_dynamic/")
        ]
        sess.run([variables.global_variables_initializer()])
        fused_dynamic_outputs, fused_dynamic_state = sess.run([outputs, state])
        fused_dynamic_grads = sess.run(
            gradients_impl.gradients(outputs, inputs))
        fused_dynamic_wgrads = sess.run(
            gradients_impl.gradients(outputs, fused_dynamic_vars))

      self.assertAllClose(basic_outputs, fused_dynamic_outputs)
      self.assertAllClose(basic_state, fused_dynamic_state)
      self.assertAllClose(basic_grads, fused_dynamic_grads)
      for basic, fused in zip(basic_wgrads, fused_dynamic_wgrads):
        self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)

  def testTimeReversedFusedRNN(self):
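    """Tests that TimeReversedFusedRNN can reproduce a bidirectional RNN."""
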
    with self.test_session() as sess:
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890213)
      fw_cell = rnn_cell.BasicRNNCell(10)
      bw_cell = rnn_cell.BasicRNNCell(10)
      batch_size = 5
      input_size = 20
      timelen = 15
      inputs = constant_op.constant(
          np.random.randn(timelen, batch_size, input_size))

      # Baseline: unfused static bidirectional RNN.
      with variable_scope.variable_scope("basic", initializer=initializer):
        unpacked_inputs = array_ops.unstack(inputs)
        outputs, fw_state, bw_state = rnn.static_bidirectional_rnn(
            fw_cell, bw_cell, unpacked_inputs, dtype=dtypes.float64)
        packed_outputs = array_ops.stack(outputs)
        basic_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("basic/")
        ]
        sess.run([variables.global_variables_initializer()])
        basic_outputs, basic_fw_state, basic_bw_state = sess.run(
            [packed_outputs, fw_state, bw_state])
        basic_grads = sess.run(gradients_impl.gradients(packed_outputs, inputs))
        basic_wgrads = sess.run(
            gradients_impl.gradients(packed_outputs, basic_vars))

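      # Rebuild the bidirectional computation from fused pieces: a plain
      # forward adaptor, plus a backward adaptor wrapped in
      # TimeReversedFusedRNN, which reverses the inputs along the time axis,
      # runs the wrapped cell, and reverses the outputs back.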
      with variable_scope.variable_scope("fused", initializer=initializer):
        fused_cell = fused_rnn_cell.FusedRNNCellAdaptor(
            rnn_cell.BasicRNNCell(10))
        fused_bw_cell = fused_rnn_cell.TimeReversedFusedRNN(
            fused_rnn_cell.FusedRNNCellAdaptor(rnn_cell.BasicRNNCell(10)))
        fw_outputs, fw_state = fused_cell(
            inputs, dtype=dtypes.float64, scope="fw")
        bw_outputs, bw_state = fused_bw_cell(
            inputs, dtype=dtypes.float64, scope="bw")
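        # Concatenating forward and backward outputs on the last axis matches
        # the output layout of static_bidirectional_rnn.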
        outputs = array_ops.concat([fw_outputs, bw_outputs], 2)
        fused_vars = [
            v for v in variables.trainable_variables()
            if v.name.startswith("fused/")
        ]
        sess.run([variables.global_variables_initializer()])
        fused_outputs, fused_fw_state, fused_bw_state = sess.run(
            [outputs, fw_state, bw_state])
        fused_grads = sess.run(gradients_impl.gradients(outputs, inputs))
        fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars))

      self.assertAllClose(basic_outputs, fused_outputs)
      self.assertAllClose(basic_fw_state, fused_fw_state)
      self.assertAllClose(basic_bw_state, fused_bw_state)
      self.assertAllClose(basic_grads, fused_grads)
      for basic, fused in zip(basic_wgrads, fused_wgrads):
        self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
  test.main()