1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Tests for contrib.seq2seq.python.seq2seq.loss_ops.""" 17 # pylint: disable=unused-import,g-bad-import-order 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 # pylint: enable=unused-import 22 23 import numpy as np 24 25 from tensorflow.contrib.seq2seq.python.ops import loss 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import init_ops 30 from tensorflow.python.ops import variable_scope 31 from tensorflow.python.platform import test 32 33 34 class LossTest(test.TestCase): 35 36 def testSequenceLoss(self): 37 with self.test_session(use_gpu=True) as sess: 38 with variable_scope.variable_scope( 39 'root', initializer=init_ops.constant_initializer(0.5)): 40 batch_size = 2 41 sequence_length = 3 42 number_of_classes = 5 43 logits = [ 44 constant_op.constant( 45 i + 0.5, shape=[batch_size, number_of_classes]) 46 for i in range(sequence_length) 47 ] 48 logits = array_ops.stack(logits, axis=1) 49 targets = [ 50 constant_op.constant( 51 i, dtypes.int32, shape=[batch_size]) 52 for i in range(sequence_length) 53 ] 54 targets = array_ops.stack(targets, axis=1) 55 weights = [ 56 constant_op.constant( 57 1.0, shape=[batch_size]) for i in range(sequence_length) 58 ] 59 weights = array_ops.stack(weights, axis=1) 60 61 average_loss_per_example = loss.sequence_loss( 62 logits, targets, weights, 63 average_across_timesteps=True, 64 average_across_batch=True) 65 res = sess.run(average_loss_per_example) 66 self.assertAllClose(1.60944, res) 67 68 average_loss_per_sequence = loss.sequence_loss( 69 logits, targets, weights, 70 average_across_timesteps=False, 71 average_across_batch=True) 72 res = sess.run(average_loss_per_sequence) 73 compare_per_sequence = np.ones((sequence_length)) * 1.60944 74 self.assertAllClose(compare_per_sequence, res) 75 76 average_loss_per_batch = loss.sequence_loss( 77 logits, targets, weights, 78 average_across_timesteps=True, 79 average_across_batch=False) 80 res = sess.run(average_loss_per_batch) 81 compare_per_batch = np.ones((batch_size)) * 1.60944 82 self.assertAllClose(compare_per_batch, res) 83 84 total_loss = loss.sequence_loss( 85 logits, targets, weights, 86 average_across_timesteps=False, 87 average_across_batch=False) 88 res = sess.run(total_loss) 89 compare_total = np.ones((batch_size, sequence_length)) * 1.60944 90 self.assertAllClose(compare_total, res) 91 92 if __name__ == '__main__': 93 test.main() 94