Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for CandidateSamplerOp."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import candidate_sampling_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.platform import test
     29 
     30 
     31 class RangeSamplerOpsTest(test.TestCase):
     32 
     33   BATCH_SIZE = 3
     34   NUM_TRUE = 2
     35   RANGE = 5
     36   NUM_SAMPLED = RANGE
     37 
     38   TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]
     39 
     40   def testTrueCandidates(self):
     41     with self.test_session() as sess:
     42       indices = constant_op.constant([0, 0, 1, 1, 2, 2])
     43       true_candidates_vec = constant_op.constant([1, 2, 0, 4, 3, 3])
     44       true_candidates_matrix = array_ops.reshape(
     45           true_candidates_vec, [self.BATCH_SIZE, self.NUM_TRUE])
     46       indices_val, true_candidates_val = sess.run(
     47           [indices, true_candidates_matrix])
     48 
     49     self.assertAllEqual(indices_val, [0, 0, 1, 1, 2, 2])
     50     self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)
     51 
     52   def testSampledCandidates(self):
     53     with self.test_session():
     54       true_classes = constant_op.constant(
     55           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
     56       sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
     57           true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
     58       result = sampled_candidates.eval()
     59 
     60     expected_ids = [0, 1, 2, 3, 4]
     61     self.assertAllEqual(result, expected_ids)
     62     self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])
     63 
     64   def testTrueLogExpectedCount(self):
     65     with self.test_session():
     66       true_classes = constant_op.constant(
     67           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
     68       _, true_expected_count, _ = candidate_sampling_ops.all_candidate_sampler(
     69           true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
     70       true_log_expected_count = math_ops.log(true_expected_count)
     71       result = true_log_expected_count.eval()
     72 
     73     self.assertAllEqual(result, [[0.0] * self.NUM_TRUE] * self.BATCH_SIZE)
     74     self.assertEqual(true_expected_count.get_shape(),
     75                      [self.BATCH_SIZE, self.NUM_TRUE])
     76     self.assertEqual(true_log_expected_count.get_shape(),
     77                      [self.BATCH_SIZE, self.NUM_TRUE])
     78 
     79   def testSampledLogExpectedCount(self):
     80     with self.test_session():
     81       true_classes = constant_op.constant(
     82           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
     83       _, _, sampled_expected_count = candidate_sampling_ops.all_candidate_sampler(  # pylint: disable=line-too-long
     84           true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
     85       sampled_log_expected_count = math_ops.log(sampled_expected_count)
     86       result = sampled_log_expected_count.eval()
     87 
     88     self.assertAllEqual(result, [0.0] * self.NUM_SAMPLED)
     89     self.assertEqual(sampled_expected_count.get_shape(), [self.NUM_SAMPLED])
     90     self.assertEqual(sampled_log_expected_count.get_shape(), [self.NUM_SAMPLED])
     91 
     92   def testAccidentalHits(self):
     93     with self.test_session() as sess:
     94       true_classes = constant_op.constant(
     95           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
     96       sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
     97           true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
     98       accidental_hits = candidate_sampling_ops.compute_accidental_hits(
     99           true_classes, sampled_candidates, self.NUM_TRUE)
    100       indices, ids, weights = sess.run(accidental_hits)
    101 
    102     self.assertEqual(1, accidental_hits[0].get_shape().ndims)
    103     self.assertEqual(1, accidental_hits[1].get_shape().ndims)
    104     self.assertEqual(1, accidental_hits[2].get_shape().ndims)
    105     for index, id_, weight in zip(indices, ids, weights):
    106       self.assertTrue(id_ in self.TRUE_LABELS[index])
    107       self.assertLess(weight, -1.0e37)
    108 
    109   def testSeed(self):
    110 
    111     def draw(seed):
    112       with self.test_session():
    113         true_classes = constant_op.constant(
    114             [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
    115         sampled, _, _ = candidate_sampling_ops.log_uniform_candidate_sampler(
    116             true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True, 5, seed=seed)
    117         return sampled.eval()
    118 
    119     # Non-zero seed. Repeatable.
    120     for seed in [1, 12, 123, 1234]:
    121       self.assertAllEqual(draw(seed), draw(seed))
    122     # Seed=0 means random seeds.
    123     num_same = 0
    124     for _ in range(10):
    125       if np.allclose(draw(None), draw(None)):
    126         num_same += 1
    127     # Accounts for the fact that the same random seed may be picked
    128     # twice very rarely.
    129     self.assertLessEqual(num_same, 2)
    130 
    131 
if __name__ == "__main__":
  # When executed as a script, hand off to the main() entry point of
  # tensorflow.python.platform.test.
  test.main()
    134