1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 import math 22 23 import numpy 24 from tensorflow.contrib.training.python.training import resample 25 from tensorflow.python.framework import constant_op 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import control_flow_ops 29 from tensorflow.python.ops import math_ops 30 from tensorflow.python.ops import variables 31 from tensorflow.python.platform import test 32 33 34 class ResampleTest(test.TestCase): 35 """Tests that resampling runs and outputs are close to expected values.""" 36 37 def testRepeatRange(self): 38 cases = [ 39 ([], []), 40 ([0], []), 41 ([1], [0]), 42 ([1, 0], [0]), 43 ([0, 1], [1]), 44 ([3], [0, 0, 0]), 45 ([0, 1, 2, 3], [1, 2, 2, 3, 3, 3]), 46 ] 47 with self.test_session() as sess: 48 for inputs, expected in cases: 49 array_inputs = numpy.array(inputs, dtype=numpy.int32) 50 actual = sess.run(resample._repeat_range(array_inputs)) 51 self.assertAllEqual(actual, expected) 52 53 def testRoundtrip(self, rate=0.25, count=5, n=500): 54 """Tests `resample(x, weights)` and resample(resample(x, rate), 1/rate)`.""" 55 56 foo = self.get_values(count) 57 bar = self.get_values(count) 58 weights = self.get_weights(count) 59 60 resampled_in, rates = resample.weighted_resample( 61 [foo, bar], constant_op.constant(weights), rate, seed=123) 62 63 resampled_back_out = resample.resample_at_rate( 64 resampled_in, 1.0 / rates, seed=456) 65 66 init = control_flow_ops.group(variables.local_variables_initializer(), 67 variables.global_variables_initializer()) 68 with self.test_session() as s: 69 s.run(init) # initialize 70 71 # outputs 72 counts_resampled = collections.Counter() 73 counts_reresampled = collections.Counter() 74 for _ in range(n): 75 resampled_vs, reresampled_vs = s.run([resampled_in, resampled_back_out]) 76 77 self.assertAllEqual(resampled_vs[0], resampled_vs[1]) 78 self.assertAllEqual(reresampled_vs[0], reresampled_vs[1]) 79 80 for v in resampled_vs[0]: 81 counts_resampled[v] += 1 82 for v in reresampled_vs[0]: 83 counts_reresampled[v] += 1 84 85 # assert that resampling worked as expected 86 self.assert_expected(weights, rate, counts_resampled, n) 87 88 # and that re-resampling gives the approx identity. 89 self.assert_expected( 90 [1.0 for _ in weights], 91 1.0, 92 counts_reresampled, 93 n, 94 abs_delta=0.1 * n * count) 95 96 def testCorrectRates(self, rate=0.25, count=10, n=500, rtol=0.1): 97 """Tests that the rates returned by weighted_resample are correct.""" 98 99 # The approach here is to verify that: 100 # - sum(1/rate) approximates the size of the original collection 101 # - sum(1/rate * value) approximates the sum of the original inputs, 102 # - sum(1/rate * value)/sum(1/rate) approximates the mean. 103 vals = self.get_values(count) 104 weights = self.get_weights(count) 105 106 resampled, rates = resample.weighted_resample([vals], 107 constant_op.constant(weights), 108 rate) 109 110 invrates = 1.0 / rates 111 112 init = control_flow_ops.group(variables.local_variables_initializer(), 113 variables.global_variables_initializer()) 114 expected_sum_op = math_ops.reduce_sum(vals) 115 with self.test_session() as s: 116 s.run(init) 117 expected_sum = n * s.run(expected_sum_op) 118 119 weight_sum = 0.0 120 weighted_value_sum = 0.0 121 for _ in range(n): 122 val, inv_rate = s.run([resampled[0], invrates]) 123 weight_sum += sum(inv_rate) 124 weighted_value_sum += sum(val * inv_rate) 125 126 # sum(inv_rate) ~= N*count: 127 expected_count = count * n 128 self.assertAlmostEqual( 129 expected_count, weight_sum, delta=(rtol * expected_count)) 130 131 # sum(vals) * n ~= weighted_sum(resampled, 1.0/weights) 132 self.assertAlmostEqual( 133 expected_sum, weighted_value_sum, delta=(rtol * expected_sum)) 134 135 # Mean ~= weighted mean: 136 expected_mean = expected_sum / float(n * count) 137 self.assertAlmostEqual( 138 expected_mean, 139 weighted_value_sum / weight_sum, 140 delta=(rtol * expected_mean)) 141 142 def testZeroRateUnknownShapes(self, count=10): 143 """Tests that resampling runs with completely runtime shapes.""" 144 # Use placeholcers without shape set: 145 vals = array_ops.placeholder(dtype=dtypes.int32) 146 rates = array_ops.placeholder(dtype=dtypes.float32) 147 148 resampled = resample.resample_at_rate([vals], rates) 149 150 with self.test_session() as s: 151 rs, = s.run(resampled, { 152 vals: list(range(count)), 153 rates: numpy.zeros( 154 shape=[count], dtype=numpy.float32) 155 }) 156 self.assertEqual(rs.shape, (0,)) 157 158 def testDtypes(self, count=10): 159 """Test that we can define the ops with float64 weights.""" 160 161 vals = self.get_values(count) 162 weights = math_ops.cast(self.get_weights(count), dtypes.float64) 163 164 # should not error: 165 resample.resample_at_rate([vals], weights) 166 resample.weighted_resample( 167 [vals], weights, overall_rate=math_ops.cast(1.0, dtypes.float64)) 168 169 def get_weights(self, n, mean=10.0, stddev=5): 170 """Returns random positive weight values.""" 171 assert mean > 0, 'Weights have to be positive.' 172 results = [] 173 while len(results) < n: 174 v = numpy.random.normal(mean, stddev) 175 if v > 0: 176 results.append(v) 177 return results 178 179 def get_values(self, n): 180 return constant_op.constant(list(range(n))) 181 182 def assert_expected(self, 183 weights, 184 overall_rate, 185 counts, 186 n, 187 tol=2.0, 188 abs_delta=0): 189 # Overall, we expect sum(counts) there to be `overall_rate` * n * 190 # len(weights)... with a stddev on that expectation equivalent to 191 # performing (n * len(weights)) trials each with probability of 192 # overall_rate. 193 expected_overall_count = len(weights) * n * overall_rate 194 actual_overall_count = sum(counts.values()) 195 196 stddev = math.sqrt(len(weights) * n * overall_rate * (1 - overall_rate)) 197 198 self.assertAlmostEqual( 199 expected_overall_count, 200 actual_overall_count, 201 delta=(stddev * tol + abs_delta)) 202 203 # And we can form a similar expectation for each item -- it should 204 # appear in the results a number of time proportional to its 205 # weight, which is similar to performing `expected_overall_count` 206 # trials each with a probability of weight/weight_sum. 207 weight_sum = sum(weights) 208 fractions = [w / weight_sum for w in weights] 209 expected_counts = [expected_overall_count * f for f in fractions] 210 211 stddevs = [ 212 math.sqrt(expected_overall_count * f * (1 - f)) for f in fractions 213 ] 214 215 for i in range(len(expected_counts)): 216 expected_count = expected_counts[i] 217 actual_count = counts[i] 218 self.assertAlmostEqual( 219 expected_count, actual_count, delta=(stddevs[i] * tol + abs_delta)) 220 221 222 if __name__ == '__main__': 223 test.main() 224