Home | History | Annotate | Download | only in tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.reverse_sequence_op."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.compiler.tests.xla_test import XLATestCase
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class ReverseSequenceTest(XLATestCase):
     30 
     31   def _testReverseSequence(self,
     32                            x,
     33                            batch_axis,
     34                            seq_axis,
     35                            seq_lengths,
     36                            truth,
     37                            expected_err_re=None):
     38     with self.test_session():
     39       p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
     40       lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
     41       with self.test_scope():
     42         ans = array_ops.reverse_sequence(
     43             p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths)
     44       if expected_err_re is None:
     45         tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths})
     46         self.assertAllClose(tf_ans, truth, atol=1e-10)
     47       else:
     48         with self.assertRaisesOpError(expected_err_re):
     49           ans.eval(feed_dict={p: x, lengths: seq_lengths})
     50 
     51   def testSimple(self):
     52     x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
     53     expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
     54     self._testReverseSequence(
     55         x,
     56         batch_axis=0,
     57         seq_axis=1,
     58         seq_lengths=np.array([1, 3, 2], np.int32),
     59         truth=expected)
     60 
     61   def _testBasic(self, dtype, len_dtype):
     62     x = np.asarray(
     63         [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]],
     64          [[17, 18, 19, 20], [21, 22, 23, 24]]],
     65         dtype=dtype)
     66     x = x.reshape(3, 2, 4, 1, 1)
     67     x = x.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2
     68 
     69     # reverse dim 2 up to (0:3, none, 0:4) along dim=0
     70     seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)
     71 
     72     truth_orig = np.asarray(
     73         [
     74             [[3, 2, 1, 4], [7, 6, 5, 8]],  # reverse 0:3
     75             [[9, 10, 11, 12], [13, 14, 15, 16]],  # reverse none
     76             [[20, 19, 18, 17], [24, 23, 22, 21]]
     77         ],  # reverse 0:4 (all)
     78         dtype=dtype)
     79     truth_orig = truth_orig.reshape(3, 2, 4, 1, 1)
     80     truth = truth_orig.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2
     81 
     82     seq_axis = 0  # permute seq_axis and batch_axis (originally 2 and 0, resp.)
     83     batch_axis = 2
     84     self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth)
     85 
     86   def testSeqLength(self):
     87     for dtype in self.all_types:
     88       for seq_dtype in self.int_types:
     89         self._testBasic(dtype, seq_dtype)
     90 
     91 
     92 if __name__ == "__main__":
     93   test.main()
     94