Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for ParameterizedTruncatedNormalOp."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import functools
     22 import math
     23 import timeit
     24 
     25 import numpy as np
     26 from six.moves import range  # pylint: disable=redefined-builtin
     27 
     28 from tensorflow.core.protobuf import config_pb2
     29 from tensorflow.python.client import session
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import random_seed
     32 from tensorflow.python.ops import control_flow_ops
     33 from tensorflow.python.ops import random_ops
     34 from tensorflow.python.platform import test
     35 from tensorflow.python.platform import tf_logging
     36 
     37 
     38 class TruncatedNormalMoments(object):
     39   memoized_moments = None
     40   mean = None
     41   stddev = None
     42   minval = None
     43   maxval = None
     44 
     45   def __init__(self, mean, stddev, minval, maxval):
     46     self.memoized_moments = [1.0]  # 0th moment
     47     self.mean = np.double(mean)
     48     self.stddev = np.double(stddev)
     49     # NOTE(ringwalt): The formula doesn't handle infinite values.
     50     self.minval = np.double(max(-10, minval))
     51     self.maxval = np.double(min(10, maxval))
     52 
     53   def __getitem__(self, moment):
     54     """Calculates the truncated normal moments.
     55 
     56     Args:
     57       moment: The number for the moment.
     58 
     59     Returns:
     60       The value for the given moment.
     61 
     62     Uses the recurrence relation described in:
     63         http://www.smp.uq.edu.au/people/YoniNazarathy/teaching_projects
     64             /studentWork/EricOrjebin_TruncatedNormalMoments.pdf
     65     """
     66     assert moment > 0
     67     # The test case must ensure it can import scipy.stats before this point.
     68     import scipy.stats  # pylint: disable=g-import-not-at-top
     69     dist = scipy.stats.norm(loc=self.mean, scale=self.stddev)
     70     for k in range(len(self.memoized_moments), moment + 1):
     71       m_k_minus_2 = self.memoized_moments[k - 2] if k > 1 else np.double(0.0)
     72       m_k_minus_1 = self.memoized_moments[k - 1]
     73       numerator = (np.power(self.maxval, k - 1) * dist.pdf(self.maxval) -
     74                    np.power(self.minval, k - 1) * dist.pdf(self.minval))
     75       denominator = dist.cdf(self.maxval) - dist.cdf(self.minval)
     76       m = ((k - 1) * self.stddev**2 * m_k_minus_2 + self.mean * m_k_minus_1 -
     77            self.stddev * numerator / denominator)
     78       assert abs(m) < 1e50  # ensure numerical accuracy
     79       self.memoized_moments.append(m)
     80     return self.memoized_moments[moment]
     81 
     82 
     83 def calculate_moments(samples, max_moment):
     84   moments = [0.0] * (max_moment + 1)
     85   for sample in samples:
     86     value = 1.0
     87     for k in range(len(moments)):
     88       moments[k] += value
     89       value *= sample
     90   for i in range(len(moments)):
     91     moments[i] /= len(samples)
     92   return moments
     93 
     94 
     95 def z_test(real, expected, i, num_samples):
     96   numerical_error = 1e-6  # per-operation error
     97   moment_mean = expected[i]
     98   moment_squared = expected[2 * i]
     99   moment_var = moment_squared - moment_mean * moment_mean
    100 
    101   error_per_moment = i * numerical_error
    102   total_variance = moment_var / float(num_samples) + error_per_moment
    103   return abs((real[i] - moment_mean) / math.sqrt(total_variance))
    104 
    105 
    106 class ParameterizedTruncatedNormalTest(test.TestCase):
    107   z_limit = 6.0
    108 
    109   # Stop at moment 10 to avoid numerical errors in the theoretical moments.
    110   max_moment = 10
    111 
    112   def validateMoments(self, shape, mean, stddev, minval, maxval, seed=1618):
    113     try:
    114       # TruncatedNormalMoments requires scipy.stats.
    115       # Give up early if we are unable to import it.
    116       import scipy.stats  # pylint: disable=g-import-not-at-top,unused-variable
    117       random_seed.set_random_seed(seed)
    118       with self.test_session(use_gpu=True):
    119         samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
    120                                                             minval,
    121                                                             maxval).eval()
    122         assert (~np.isnan(samples)).all()
    123       moments = calculate_moments(samples, self.max_moment)
    124       expected_moments = TruncatedNormalMoments(mean, stddev, minval, maxval)
    125       num_samples = functools.reduce(lambda x, y: x * y, shape, 1)
    126       for i in range(1, len(moments)):
    127         self.assertLess(
    128             z_test(moments, expected_moments, i, num_samples), self.z_limit)
    129     except ImportError as e:
    130       tf_logging.warn("Cannot test truncated normal op: %s" % str(e))
    131 
    132   def validateKolmogorovSmirnov(self,
    133                                 shape,
    134                                 mean,
    135                                 stddev,
    136                                 minval,
    137                                 maxval,
    138                                 seed=1618):
    139     try:
    140       import scipy.stats  # pylint: disable=g-import-not-at-top
    141       random_seed.set_random_seed(seed)
    142       with self.test_session(use_gpu=True):
    143         samples = random_ops.parameterized_truncated_normal(shape, mean, stddev,
    144                                                             minval,
    145                                                             maxval).eval()
    146       assert (~np.isnan(samples)).all()
    147       minval = max(mean - stddev * 10, minval)
    148       maxval = min(mean + stddev * 10, maxval)
    149       dist = scipy.stats.norm(loc=mean, scale=stddev)
    150       cdf_min = dist.cdf(minval)
    151       cdf_max = dist.cdf(maxval)
    152 
    153       def truncated_cdf(x):
    154         return np.clip((dist.cdf(x) - cdf_min) / (cdf_max - cdf_min), 0.0, 1.0)
    155 
    156       pvalue = scipy.stats.kstest(samples, truncated_cdf)[1]
    157       self.assertGreater(pvalue, 1e-10)
    158     except ImportError as e:
    159       tf_logging.warn("Cannot test truncated normal op: %s" % str(e))
    160 
    161   def testDefaults(self):
    162     self.validateMoments([10**5], 0.0, 1.0, -2.0, 2.0)
    163 
    164   def testShifted(self):
    165     self.validateMoments([10**5], -1.0, 1.0, -2.0, 2.0)
    166 
    167   def testRightTail(self):
    168     self.validateMoments([10**5], 0.0, 1.0, 4.0, np.infty)
    169 
    170   def testLeftTail(self):
    171     self.validateMoments([10**5], 0.0, 1.0, -np.infty, -4.0)
    172 
    173   def testLeftTailTwoSidedBounds(self):
    174     self.validateMoments([10**5], 0.0, 1.0, -6.0, -3.0)
    175 
    176   def testTwoSidedLeftTailShifted(self):
    177     self.validateKolmogorovSmirnov([10**5], 6.0, 1.0, -1.0, 1.0)
    178 
    179   def testRightTailShifted(self):
    180     self.validateMoments([10**5], -5.0, 1.0, 2.0, np.infty)
    181 
    182   def testSmallStddev(self):
    183     self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
    184 
    185 
    186 # Benchmarking code
    187 def parameterized_vs_naive(shape, num_iters, use_gpu=False):
    188   np.random.seed(1618)  # Make it reproducible.
    189 
    190   # No CSE/CF.
    191   optimizer_options = config_pb2.OptimizerOptions(
    192       opt_level=config_pb2.OptimizerOptions.L0)
    193   config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
    194       optimizer_options=optimizer_options))
    195 
    196   with session.Session(config=config) as sess:
    197     with ops.device("/cpu:0" if not use_gpu else None):
    198       param_op = control_flow_ops.group(
    199           random_ops.parameterized_truncated_normal(shape))
    200       naive_op = control_flow_ops.group(random_ops.truncated_normal(shape))
    201 
    202     # Burn-in to avoid session setup costs in the timing.
    203     sess.run(param_op)
    204     sess.run(param_op)
    205     param_dt = timeit.timeit(lambda: sess.run(param_op), number=num_iters)
    206     sess.run(naive_op)
    207     sess.run(naive_op)
    208     naive_dt = timeit.timeit(lambda: sess.run(naive_op), number=num_iters)
    209     return param_dt, naive_dt
    210 
    211 
    212 class TruncatedNormalBenchmark(test.Benchmark):
    213 
    214   def benchmarkParameterizedOpVsNaiveOpCpu(self):
    215     self._benchmarkParameterizedOpVsNaiveOp(False)
    216 
    217   def benchmarkParameterizedOpVsNaiveOpGpu(self):
    218     self._benchmarkParameterizedOpVsNaiveOp(True)
    219 
    220   def _benchmarkParameterizedOpVsNaiveOp(self, use_gpu):
    221     num_iters = 50
    222     print(("Composition of new ParameterizedTruncatedNormalOp vs. "
    223            "naive TruncatedNormalOp [%d iters]") % num_iters)
    224     print("Shape\tsec(parameterized)\tsec(naive)\tspeedup")
    225 
    226     for shape in [[10000, 100], [1000, 1000], [1000000], [100, 100, 100],
    227                   [20, 20, 20, 20]]:
    228       p_dt, n_dt = parameterized_vs_naive(shape, num_iters, use_gpu)
    229       print("%s\t%.3f\t%.3f\t%.2f" % (shape, p_dt, n_dt, p_dt / n_dt))
    230 
    231       shape_str = "-".join(map(str, shape))
    232       self.report_benchmark(
    233           name="parameterized_shape" + shape_str,
    234           iters=num_iters,
    235           wall_time=p_dt)
    236       self.report_benchmark(
    237           name="naive_shape" + shape_str, iters=num_iters, wall_time=n_dt)
    238 
    239 
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner
  # (tensorflow.python.platform.test).
  test.main()
    242