# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for contrib.seq2seq.python.ops.loss."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import

import numpy as np

from tensorflow.contrib.seq2seq.python.ops import loss
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test


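# `loss.sequence_loss` computes weighted softmax cross-entropy over a batch
# of logit sequences; its `average_across_timesteps` and
# `average_across_batch` flags control whether the per-element losses are
# reduced over the time and/or batch dimensions. The test below exercises
# all four combinations of these flags.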
class LossTest(test.TestCase):

  def testSequenceLoss(self):
    with self.test_session(use_gpu=True) as sess:
      with variable_scope.variable_scope(
          'root', initializer=init_ops.constant_initializer(0.5)):
        batch_size = 2
        sequence_length = 3
        number_of_classes = 5
        logits = [
            constant_op.constant(
                i + 0.5, shape=[batch_size, number_of_classes])
            for i in range(sequence_length)
        ]
        logits = array_ops.stack(logits, axis=1)
        targets = [
            constant_op.constant(
                i, dtypes.int32, shape=[batch_size])
            for i in range(sequence_length)
        ]
        targets = array_ops.stack(targets, axis=1)
        weights = [
            constant_op.constant(
                1.0, shape=[batch_size]) for i in range(sequence_length)
        ]
        weights = array_ops.stack(weights, axis=1)

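        # At every timestep the logits are identical across all
        # `number_of_classes` classes, so the softmax is uniform and the
        # cross-entropy of each element is ln(5) ~= 1.60944 regardless of
        # the target. Each reduction below is therefore compared against
        # (arrays of) this constant.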
        average_loss_per_example = loss.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=True,
            average_across_batch=True)
        res = sess.run(average_loss_per_example)
        self.assertAllClose(1.60944, res)

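        # Averaging over the batch but not over time yields one loss value
        # per timestep, i.e. shape [sequence_length].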
        average_loss_per_sequence = loss.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=False,
            average_across_batch=True)
        res = sess.run(average_loss_per_sequence)
        compare_per_sequence = np.ones((sequence_length)) * 1.60944
        self.assertAllClose(compare_per_sequence, res)

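        # Averaging over timesteps but not over the batch yields one loss
        # value per batch entry, i.e. shape [batch_size].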
        average_loss_per_batch = loss.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=True,
            average_across_batch=False)
        res = sess.run(average_loss_per_batch)
        compare_per_batch = np.ones((batch_size)) * 1.60944
        self.assertAllClose(compare_per_batch, res)

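        # With no averaging at all, the unreduced per-element losses have
        # shape [batch_size, sequence_length].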
        total_loss = loss.sequence_loss(
            logits, targets, weights,
            average_across_timesteps=False,
            average_across_batch=False)
        res = sess.run(total_loss)
        compare_total = np.ones((batch_size, sequence_length)) * 1.60944
        self.assertAllClose(compare_total, res)

if __name__ == '__main__':
  test.main()