1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for ParameterizedTruncatedNormalOp.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import functools 22 import math 23 import timeit 24 25 import numpy as np 26 from six.moves import range # pylint: disable=redefined-builtin 27 28 from tensorflow.core.protobuf import config_pb2 29 from tensorflow.python.client import session 30 from tensorflow.python.framework import ops 31 from tensorflow.python.framework import random_seed 32 from tensorflow.python.ops import control_flow_ops 33 from tensorflow.python.ops import random_ops 34 from tensorflow.python.platform import test 35 from tensorflow.python.platform import tf_logging 36 37 38 class TruncatedNormalMoments(object): 39 memoized_moments = None 40 mean = None 41 stddev = None 42 minval = None 43 maxval = None 44 45 def __init__(self, mean, stddev, minval, maxval): 46 self.memoized_moments = [1.0] # 0th moment 47 self.mean = np.double(mean) 48 self.stddev = np.double(stddev) 49 # NOTE(ringwalt): The formula doesn't handle infinite values. 50 self.minval = np.double(max(-10, minval)) 51 self.maxval = np.double(min(10, maxval)) 52 53 def __getitem__(self, moment): 54 """Calculates the truncated normal moments. 55 56 Args: 57 moment: The number for the moment. 58 59 Returns: 60 The value for the given moment. 61 62 Uses the recurrence relation described in: 63 http://www.smp.uq.edu.au/people/YoniNazarathy/teaching_projects 64 /studentWork/EricOrjebin_TruncatedNormalMoments.pdf 65 """ 66 assert moment > 0 67 # The test case must ensure it can import scipy.stats before this point. 68 import scipy.stats # pylint: disable=g-import-not-at-top 69 dist = scipy.stats.norm(loc=self.mean, scale=self.stddev) 70 for k in range(len(self.memoized_moments), moment + 1): 71 m_k_minus_2 = self.memoized_moments[k - 2] if k > 1 else np.double(0.0) 72 m_k_minus_1 = self.memoized_moments[k - 1] 73 numerator = (np.power(self.maxval, k - 1) * dist.pdf(self.maxval) - 74 np.power(self.minval, k - 1) * dist.pdf(self.minval)) 75 denominator = dist.cdf(self.maxval) - dist.cdf(self.minval) 76 m = ((k - 1) * self.stddev**2 * m_k_minus_2 + self.mean * m_k_minus_1 - 77 self.stddev * numerator / denominator) 78 assert abs(m) < 1e50 # ensure numerical accuracy 79 self.memoized_moments.append(m) 80 return self.memoized_moments[moment] 81 82 83 def calculate_moments(samples, max_moment): 84 moments = [0.0] * (max_moment + 1) 85 for sample in samples: 86 value = 1.0 87 for k in range(len(moments)): 88 moments[k] += value 89 value *= sample 90 for i in range(len(moments)): 91 moments[i] /= len(samples) 92 return moments 93 94 95 def z_test(real, expected, i, num_samples): 96 numerical_error = 1e-6 # per-operation error 97 moment_mean = expected[i] 98 moment_squared = expected[2 * i] 99 moment_var = moment_squared - moment_mean * moment_mean 100 101 error_per_moment = i * numerical_error 102 total_variance = moment_var / float(num_samples) + error_per_moment 103 return abs((real[i] - moment_mean) / math.sqrt(total_variance)) 104 105 106 class ParameterizedTruncatedNormalTest(test.TestCase): 107 z_limit = 6.0 108 109 # Stop at moment 10 to avoid numerical errors in the theoretical moments. 110 max_moment = 10 111 112 def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618): 113 try: 114 # TruncatedNormalMoments requires scipy.stats. 115 # Give up early if we are unable to import it. 116 import scipy.stats # pylint: disable=g-import-not-at-top,unused-variable 117 random_seed.set_random_seed(seed) 118 with self.test_session(use_gpu=True): 119 samples = random_ops.parameterized_truncated_normal(shape, mean, stddev, 120 minval, 121 maxval).eval() 122 assert (~np.isnan(samples)).all() 123 moments = calculate_moments(samples, self.max_moment) 124 expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval) 125 num_samples = functools.reduce(lambda x, y: x * y, shape, 1) 126 for i in range(1, len(moments)): 127 self.assertLess( 128 z_test(moments, expected_moments, i, num_samples), self.z_limit) 129 except ImportError as e: 130 tf_logging.warn("Cannot test truncated normal op: %s" % str(e)) 131 132 def validateKolmogorovSmirnov(self, 133 shape, 134 mean, 135 stddev, 136 minval, 137 maxval, 138 seed=1618): 139 try: 140 import scipy.stats # pylint: disable=g-import-not-at-top 141 random_seed.set_random_seed(seed) 142 with self.test_session(use_gpu=True): 143 samples = random_ops.parameterized_truncated_normal(shape, mean, stddev, 144 minval, 145 maxval).eval() 146 assert (~np.isnan(samples)).all() 147 minval = max(mean - stddev * 10, minval) 148 maxval = min(mean + stddev * 10, maxval) 149 dist = scipy.stats.norm(loc=mean, scale=stddev) 150 cdf_min = dist.cdf(minval) 151 cdf_max = dist.cdf(maxval) 152 153 def truncated_cdf(x): 154 return np.clip((dist.cdf(x) - cdf_min) / (cdf_max - cdf_min), 0.0, 1.0) 155 156 pvalue = scipy.stats.kstest(samples, truncated_cdf)[1] 157 self.assertGreater(pvalue, 1e-10) 158 except ImportError as e: 159 tf_logging.warn("Cannot test truncated normal op: %s" % str(e)) 160 161 def testDefaults(self): 162 self.validateMoments([10**5], 0.0, 1.0, -2.0, 2.0) 163 164 def testShifted(self): 165 self.validateMoments([10**5], -1.0, 1.0, -2.0, 2.0) 166 167 def testRightTail(self): 168 self.validateMoments([10**5], 0.0, 1.0, 4.0, np.infty) 169 170 def testLeftTail(self): 171 self.validateMoments([10**5], 0.0, 1.0, -np.infty, -4.0) 172 173 def testLeftTailTwoSidedBounds(self): 174 self.validateMoments([10**5], 0.0, 1.0, -6.0, -3.0) 175 176 def testTwoSidedLeftTailShifted(self): 177 self.validateKolmogorovSmirnov([10**5], 6.0, 1.0, -1.0, 1.0) 178 179 def testRightTailShifted(self): 180 self.validateMoments([10**5], -5.0, 1.0, 2.0, np.infty) 181 182 def testSmallStddev(self): 183 self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10) 184 185 186 # Benchmarking code 187 def parameterized_vs_naive(shape, num_iters, use_gpu=False): 188 np.random.seed(1618) # Make it reproducible. 189 190 # No CSE/CF. 191 optimizer_options = config_pb2.OptimizerOptions( 192 opt_level=config_pb2.OptimizerOptions.L0) 193 config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( 194 optimizer_options=optimizer_options)) 195 196 with session.Session(config=config) as sess: 197 with ops.device("/cpu:0" if not use_gpu else None): 198 param_op = control_flow_ops.group( 199 random_ops.parameterized_truncated_normal(shape)) 200 naive_op = control_flow_ops.group(random_ops.truncated_normal(shape)) 201 202 # Burn-in to avoid session setup costs in the timing. 203 sess.run(param_op) 204 sess.run(param_op) 205 param_dt = timeit.timeit(lambda: sess.run(param_op), number=num_iters) 206 sess.run(naive_op) 207 sess.run(naive_op) 208 naive_dt = timeit.timeit(lambda: sess.run(naive_op), number=num_iters) 209 return param_dt, naive_dt 210 211 212 class TruncatedNormalBenchmark(test.Benchmark): 213 214 def benchmarkParameterizedOpVsNaiveOpCpu(self): 215 self._benchmarkParameterizedOpVsNaiveOp(False) 216 217 def benchmarkParameterizedOpVsNaiveOpGpu(self): 218 self._benchmarkParameterizedOpVsNaiveOp(True) 219 220 def _benchmarkParameterizedOpVsNaiveOp(self, use_gpu): 221 num_iters = 50 222 print(("Composition of new ParameterizedTruncatedNormalOp vs. " 223 "naive TruncatedNormalOp [%d iters]") % num_iters) 224 print("Shape\tsec(parameterized)\tsec(naive)\tspeedup") 225 226 for shape in [[10000, 100], [1000, 1000], [1000000], [100, 100, 100], 227 [20, 20, 20, 20]]: 228 p_dt, n_dt = parameterized_vs_naive(shape, num_iters, use_gpu) 229 print("%s\t%.3f\t%.3f\t%.2f" % (shape, p_dt, n_dt, p_dt / n_dt)) 230 231 shape_str = "-".join(map(str, shape)) 232 self.report_benchmark( 233 name="parameterized_shape" + shape_str, 234 iters=num_iters, 235 wall_time=p_dt) 236 self.report_benchmark( 237 name="naive_shape" + shape_str, iters=num_iters, wall_time=n_dt) 238 239 240 if __name__ == "__main__": 241 test.main() 242