Home | History | Annotate | Download | only in training
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import collections
     21 import math
     22 
     23 import numpy
     24 from tensorflow.contrib.training.python.training import resample
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import control_flow_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.ops import variables
     31 from tensorflow.python.platform import test
     32 
     33 
     34 class ResampleTest(test.TestCase):
     35   """Tests that resampling runs and outputs are close to expected values."""
     36 
     37   def testRepeatRange(self):
     38     cases = [
     39         ([], []),
     40         ([0], []),
     41         ([1], [0]),
     42         ([1, 0], [0]),
     43         ([0, 1], [1]),
     44         ([3], [0, 0, 0]),
     45         ([0, 1, 2, 3], [1, 2, 2, 3, 3, 3]),
     46     ]
     47     with self.test_session() as sess:
     48       for inputs, expected in cases:
     49         array_inputs = numpy.array(inputs, dtype=numpy.int32)
     50         actual = sess.run(resample._repeat_range(array_inputs))
     51         self.assertAllEqual(actual, expected)
     52 
     53   def testRoundtrip(self, rate=0.25, count=5, n=500):
     54     """Tests `resample(x, weights)` and resample(resample(x, rate), 1/rate)`."""
     55 
     56     foo = self.get_values(count)
     57     bar = self.get_values(count)
     58     weights = self.get_weights(count)
     59 
     60     resampled_in, rates = resample.weighted_resample(
     61         [foo, bar], constant_op.constant(weights), rate, seed=123)
     62 
     63     resampled_back_out = resample.resample_at_rate(
     64         resampled_in, 1.0 / rates, seed=456)
     65 
     66     init = control_flow_ops.group(variables.local_variables_initializer(),
     67                                   variables.global_variables_initializer())
     68     with self.test_session() as s:
     69       s.run(init)  # initialize
     70 
     71       # outputs
     72       counts_resampled = collections.Counter()
     73       counts_reresampled = collections.Counter()
     74       for _ in range(n):
     75         resampled_vs, reresampled_vs = s.run([resampled_in, resampled_back_out])
     76 
     77         self.assertAllEqual(resampled_vs[0], resampled_vs[1])
     78         self.assertAllEqual(reresampled_vs[0], reresampled_vs[1])
     79 
     80         for v in resampled_vs[0]:
     81           counts_resampled[v] += 1
     82         for v in reresampled_vs[0]:
     83           counts_reresampled[v] += 1
     84 
     85       # assert that resampling worked as expected
     86       self.assert_expected(weights, rate, counts_resampled, n)
     87 
     88       # and that re-resampling gives the approx identity.
     89       self.assert_expected(
     90           [1.0 for _ in weights],
     91           1.0,
     92           counts_reresampled,
     93           n,
     94           abs_delta=0.1 * n * count)
     95 
     96   def testCorrectRates(self, rate=0.25, count=10, n=500, rtol=0.1):
     97     """Tests that the rates returned by weighted_resample are correct."""
     98 
     99     # The approach here is to verify that:
    100     #  - sum(1/rate) approximates the size of the original collection
    101     #  - sum(1/rate * value) approximates the sum of the original inputs,
    102     #  - sum(1/rate * value)/sum(1/rate) approximates the mean.
    103     vals = self.get_values(count)
    104     weights = self.get_weights(count)
    105 
    106     resampled, rates = resample.weighted_resample([vals],
    107                                                   constant_op.constant(weights),
    108                                                   rate)
    109 
    110     invrates = 1.0 / rates
    111 
    112     init = control_flow_ops.group(variables.local_variables_initializer(),
    113                                   variables.global_variables_initializer())
    114     expected_sum_op = math_ops.reduce_sum(vals)
    115     with self.test_session() as s:
    116       s.run(init)
    117       expected_sum = n * s.run(expected_sum_op)
    118 
    119       weight_sum = 0.0
    120       weighted_value_sum = 0.0
    121       for _ in range(n):
    122         val, inv_rate = s.run([resampled[0], invrates])
    123         weight_sum += sum(inv_rate)
    124         weighted_value_sum += sum(val * inv_rate)
    125 
    126     # sum(inv_rate) ~= N*count:
    127     expected_count = count * n
    128     self.assertAlmostEqual(
    129         expected_count, weight_sum, delta=(rtol * expected_count))
    130 
    131     # sum(vals) * n ~= weighted_sum(resampled, 1.0/weights)
    132     self.assertAlmostEqual(
    133         expected_sum, weighted_value_sum, delta=(rtol * expected_sum))
    134 
    135     # Mean ~= weighted mean:
    136     expected_mean = expected_sum / float(n * count)
    137     self.assertAlmostEqual(
    138         expected_mean,
    139         weighted_value_sum / weight_sum,
    140         delta=(rtol * expected_mean))
    141 
    142   def testZeroRateUnknownShapes(self, count=10):
    143     """Tests that resampling runs with completely runtime shapes."""
    144     # Use placeholcers without shape set:
    145     vals = array_ops.placeholder(dtype=dtypes.int32)
    146     rates = array_ops.placeholder(dtype=dtypes.float32)
    147 
    148     resampled = resample.resample_at_rate([vals], rates)
    149 
    150     with self.test_session() as s:
    151       rs, = s.run(resampled, {
    152           vals: list(range(count)),
    153           rates: numpy.zeros(
    154               shape=[count], dtype=numpy.float32)
    155       })
    156       self.assertEqual(rs.shape, (0,))
    157 
    158   def testDtypes(self, count=10):
    159     """Test that we can define the ops with float64 weights."""
    160 
    161     vals = self.get_values(count)
    162     weights = math_ops.cast(self.get_weights(count), dtypes.float64)
    163 
    164     # should not error:
    165     resample.resample_at_rate([vals], weights)
    166     resample.weighted_resample(
    167         [vals], weights, overall_rate=math_ops.cast(1.0, dtypes.float64))
    168 
    169   def get_weights(self, n, mean=10.0, stddev=5):
    170     """Returns random positive weight values."""
    171     assert mean > 0, 'Weights have to be positive.'
    172     results = []
    173     while len(results) < n:
    174       v = numpy.random.normal(mean, stddev)
    175       if v > 0:
    176         results.append(v)
    177     return results
    178 
    179   def get_values(self, n):
    180     return constant_op.constant(list(range(n)))
    181 
    182   def assert_expected(self,
    183                       weights,
    184                       overall_rate,
    185                       counts,
    186                       n,
    187                       tol=2.0,
    188                       abs_delta=0):
    189     # Overall, we expect sum(counts) there to be `overall_rate` * n *
    190     # len(weights)...  with a stddev on that expectation equivalent to
    191     # performing (n * len(weights)) trials each with probability of
    192     # overall_rate.
    193     expected_overall_count = len(weights) * n * overall_rate
    194     actual_overall_count = sum(counts.values())
    195 
    196     stddev = math.sqrt(len(weights) * n * overall_rate * (1 - overall_rate))
    197 
    198     self.assertAlmostEqual(
    199         expected_overall_count,
    200         actual_overall_count,
    201         delta=(stddev * tol + abs_delta))
    202 
    203     # And we can form a similar expectation for each item -- it should
    204     # appear in the results a number of time proportional to its
    205     # weight, which is similar to performing `expected_overall_count`
    206     # trials each with a probability of weight/weight_sum.
    207     weight_sum = sum(weights)
    208     fractions = [w / weight_sum for w in weights]
    209     expected_counts = [expected_overall_count * f for f in fractions]
    210 
    211     stddevs = [
    212         math.sqrt(expected_overall_count * f * (1 - f)) for f in fractions
    213     ]
    214 
    215     for i in range(len(expected_counts)):
    216       expected_count = expected_counts[i]
    217       actual_count = counts[i]
    218       self.assertAlmostEqual(
    219           expected_count, actual_count, delta=(stddevs[i] * tol + abs_delta))
    220 
    221 
    222 if __name__ == '__main__':
    223   test.main()
    224