Home | History | Annotate | Download | only in estimators
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for layers.rnn_common."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.learn.python.learn.estimators import rnn_common
     24 from tensorflow.python.client import session
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class RnnCommonTest(test.TestCase):
     31 
     32   def testMaskActivationsAndLabels(self):
     33     """Test `mask_activations_and_labels`."""
     34     batch_size = 4
     35     padded_length = 6
     36     num_classes = 4
     37     np.random.seed(1234)
     38     sequence_length = np.random.randint(0, padded_length + 1, batch_size)
     39     activations = np.random.rand(batch_size, padded_length, num_classes)
     40     labels = np.random.randint(0, num_classes, [batch_size, padded_length])
     41     (activations_masked_t,
     42      labels_masked_t) = rnn_common.mask_activations_and_labels(
     43          constant_op.constant(activations, dtype=dtypes.float32),
     44          constant_op.constant(labels, dtype=dtypes.int32),
     45          constant_op.constant(sequence_length, dtype=dtypes.int32))
     46 
     47     with self.cached_session() as sess:
     48       activations_masked, labels_masked = sess.run(
     49           [activations_masked_t, labels_masked_t])
     50 
     51     expected_activations_shape = [sum(sequence_length), num_classes]
     52     np.testing.assert_equal(
     53         expected_activations_shape, activations_masked.shape,
     54         'Wrong activations shape. Expected {}; got {}.'.format(
     55             expected_activations_shape, activations_masked.shape))
     56 
     57     expected_labels_shape = [sum(sequence_length)]
     58     np.testing.assert_equal(expected_labels_shape, labels_masked.shape,
     59                             'Wrong labels shape. Expected {}; got {}.'.format(
     60                                 expected_labels_shape, labels_masked.shape))
     61     masked_index = 0
     62     for i in range(batch_size):
     63       for j in range(sequence_length[i]):
     64         actual_activations = activations_masked[masked_index]
     65         expected_activations = activations[i, j, :]
     66         np.testing.assert_almost_equal(
     67             expected_activations,
     68             actual_activations,
     69             err_msg='Unexpected logit value at index [{}, {}, :].'
     70             '  Expected {}; got {}.'.format(i, j, expected_activations,
     71                                             actual_activations))
     72 
     73         actual_labels = labels_masked[masked_index]
     74         expected_labels = labels[i, j]
     75         np.testing.assert_almost_equal(
     76             expected_labels,
     77             actual_labels,
     78             err_msg='Unexpected logit value at index [{}, {}].'
     79             ' Expected {}; got {}.'.format(i, j, expected_labels,
     80                                            actual_labels))
     81         masked_index += 1
     82 
     83   def testSelectLastActivations(self):
     84     """Test `select_last_activations`."""
     85     batch_size = 4
     86     padded_length = 6
     87     num_classes = 4
     88     np.random.seed(4444)
     89     sequence_length = np.random.randint(0, padded_length + 1, batch_size)
     90     activations = np.random.rand(batch_size, padded_length, num_classes)
     91     last_activations_t = rnn_common.select_last_activations(
     92         constant_op.constant(activations, dtype=dtypes.float32),
     93         constant_op.constant(sequence_length, dtype=dtypes.int32))
     94 
     95     with session.Session() as sess:
     96       last_activations = sess.run(last_activations_t)
     97 
     98     expected_activations_shape = [batch_size, num_classes]
     99     np.testing.assert_equal(
    100         expected_activations_shape, last_activations.shape,
    101         'Wrong activations shape. Expected {}; got {}.'.format(
    102             expected_activations_shape, last_activations.shape))
    103 
    104     for i in range(batch_size):
    105       actual_activations = last_activations[i, :]
    106       expected_activations = activations[i, sequence_length[i] - 1, :]
    107       np.testing.assert_almost_equal(
    108           expected_activations,
    109           actual_activations,
    110           err_msg='Unexpected logit value at index [{}, :].'
    111           '  Expected {}; got {}.'.format(i, expected_activations,
    112                                           actual_activations))
    113 
    114 
# Script entry point: when this file is executed directly, hand control to
# `test.main()` from `tensorflow.python.platform.test`.
if __name__ == '__main__':
  test.main()
    117