1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for CandidateSamplerOp.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.ops import candidate_sampling_ops 27 from tensorflow.python.ops import math_ops 28 from tensorflow.python.platform import test 29 30 31 class RangeSamplerOpsTest(test.TestCase): 32 33 BATCH_SIZE = 3 34 NUM_TRUE = 2 35 RANGE = 5 36 NUM_SAMPLED = RANGE 37 38 TRUE_LABELS = [[1, 2], [0, 4], [3, 3]] 39 40 def testTrueCandidates(self): 41 with self.test_session() as sess: 42 indices = constant_op.constant([0, 0, 1, 1, 2, 2]) 43 true_candidates_vec = constant_op.constant([1, 2, 0, 4, 3, 3]) 44 true_candidates_matrix = array_ops.reshape( 45 true_candidates_vec, [self.BATCH_SIZE, self.NUM_TRUE]) 46 indices_val, true_candidates_val = sess.run( 47 [indices, true_candidates_matrix]) 48 49 self.assertAllEqual(indices_val, [0, 0, 1, 1, 2, 2]) 50 self.assertAllEqual(true_candidates_val, self.TRUE_LABELS) 51 52 def testSampledCandidates(self): 53 with self.test_session(): 54 true_classes = constant_op.constant( 55 [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64) 56 sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler( 57 true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True) 58 result = sampled_candidates.eval() 59 60 expected_ids = [0, 1, 2, 3, 4] 61 self.assertAllEqual(result, expected_ids) 62 self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED]) 63 64 def testTrueLogExpectedCount(self): 65 with self.test_session(): 66 true_classes = constant_op.constant( 67 [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64) 68 _, true_expected_count, _ = candidate_sampling_ops.all_candidate_sampler( 69 true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True) 70 true_log_expected_count = math_ops.log(true_expected_count) 71 result = true_log_expected_count.eval() 72 73 self.assertAllEqual(result, [[0.0] * self.NUM_TRUE] * self.BATCH_SIZE) 74 self.assertEqual(true_expected_count.get_shape(), 75 [self.BATCH_SIZE, self.NUM_TRUE]) 76 self.assertEqual(true_log_expected_count.get_shape(), 77 [self.BATCH_SIZE, self.NUM_TRUE]) 78 79 def testSampledLogExpectedCount(self): 80 with self.test_session(): 81 true_classes = constant_op.constant( 82 [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64) 83 _, _, sampled_expected_count = candidate_sampling_ops.all_candidate_sampler( # pylint: disable=line-too-long 84 true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True) 85 sampled_log_expected_count = math_ops.log(sampled_expected_count) 86 result = sampled_log_expected_count.eval() 87 88 self.assertAllEqual(result, [0.0] * self.NUM_SAMPLED) 89 self.assertEqual(sampled_expected_count.get_shape(), [self.NUM_SAMPLED]) 90 self.assertEqual(sampled_log_expected_count.get_shape(), [self.NUM_SAMPLED]) 91 92 def testAccidentalHits(self): 93 with self.test_session() as sess: 94 true_classes = constant_op.constant( 95 [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64) 96 sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler( 97 true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True) 98 accidental_hits = candidate_sampling_ops.compute_accidental_hits( 99 true_classes, sampled_candidates, self.NUM_TRUE) 100 indices, ids, weights = sess.run(accidental_hits) 101 102 self.assertEqual(1, accidental_hits[0].get_shape().ndims) 103 self.assertEqual(1, accidental_hits[1].get_shape().ndims) 104 self.assertEqual(1, accidental_hits[2].get_shape().ndims) 105 for index, id_, weight in zip(indices, ids, weights): 106 self.assertTrue(id_ in self.TRUE_LABELS[index]) 107 self.assertLess(weight, -1.0e37) 108 109 def testSeed(self): 110 111 def draw(seed): 112 with self.test_session(): 113 true_classes = constant_op.constant( 114 [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64) 115 sampled, _, _ = candidate_sampling_ops.log_uniform_candidate_sampler( 116 true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True, 5, seed=seed) 117 return sampled.eval() 118 119 # Non-zero seed. Repeatable. 120 for seed in [1, 12, 123, 1234]: 121 self.assertAllEqual(draw(seed), draw(seed)) 122 # Seed=0 means random seeds. 123 num_same = 0 124 for _ in range(10): 125 if np.allclose(draw(None), draw(None)): 126 num_same += 1 127 # Accounts for the fact that the same random seed may be picked 128 # twice very rarely. 129 self.assertLessEqual(num_same, 2) 130 131 132 if __name__ == "__main__": 133 test.main() 134