1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.reverse_sequence_op.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.compiler.tests.xla_test import XLATestCase 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.platform import test 27 28 29 class ReverseSequenceTest(XLATestCase): 30 31 def _testReverseSequence(self, 32 x, 33 batch_axis, 34 seq_axis, 35 seq_lengths, 36 truth, 37 expected_err_re=None): 38 with self.test_session(): 39 p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) 40 lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) 41 with self.test_scope(): 42 ans = array_ops.reverse_sequence( 43 p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths) 44 if expected_err_re is None: 45 tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths}) 46 self.assertAllClose(tf_ans, truth, atol=1e-10) 47 else: 48 with self.assertRaisesOpError(expected_err_re): 49 ans.eval(feed_dict={p: x, lengths: seq_lengths}) 50 51 def testSimple(self): 52 x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32) 53 expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32) 54 self._testReverseSequence( 55 x, 56 batch_axis=0, 57 seq_axis=1, 58 seq_lengths=np.array([1, 3, 2], np.int32), 59 truth=expected) 60 61 def _testBasic(self, dtype, len_dtype): 62 x = np.asarray( 63 [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]], 64 [[17, 18, 19, 20], [21, 22, 23, 24]]], 65 dtype=dtype) 66 x = x.reshape(3, 2, 4, 1, 1) 67 x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 68 69 # reverse dim 2 up to (0:3, none, 0:4) along dim=0 70 seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype) 71 72 truth_orig = np.asarray( 73 [ 74 [[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3 75 [[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none 76 [[20, 19, 18, 17], [24, 23, 22, 21]] 77 ], # reverse 0:4 (all) 78 dtype=dtype) 79 truth_orig = truth_orig.reshape(3, 2, 4, 1, 1) 80 truth = truth_orig.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 81 82 seq_axis = 0 # permute seq_axis and batch_axis (originally 2 and 0, resp.) 83 batch_axis = 2 84 self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth) 85 86 def testSeqLength(self): 87 for dtype in self.all_types: 88 for seq_dtype in self.int_types: 89 self._testBasic(dtype, seq_dtype) 90 91 92 if __name__ == "__main__": 93 test.main() 94