1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Benchmarks for `tf.data.experimental.rejection_resample()`.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import time 21 22 import numpy as np 23 from six.moves import xrange # pylint: disable=redefined-builtin 24 25 from tensorflow.python.client import session 26 from tensorflow.python.data.experimental.ops import resampling 27 from tensorflow.python.data.ops import dataset_ops 28 from tensorflow.python.platform import test 29 30 31 def _time_resampling(data_np, target_dist, init_dist, num_to_sample): # pylint: disable=missing-docstring 32 dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat() 33 34 # Reshape distribution via rejection sampling. 35 dataset = dataset.apply( 36 resampling.rejection_resample( 37 class_func=lambda x: x, 38 target_dist=target_dist, 39 initial_dist=init_dist, 40 seed=142)) 41 42 options = dataset_ops.Options() 43 options.experimental_optimization.apply_default_optimizations = False 44 dataset = dataset.with_options(options) 45 get_next = dataset_ops.make_one_shot_iterator(dataset).get_next() 46 47 with session.Session() as sess: 48 start_time = time.time() 49 for _ in xrange(num_to_sample): 50 sess.run(get_next) 51 end_time = time.time() 52 53 return end_time - start_time 54 55 56 class RejectionResampleBenchmark(test.Benchmark): 57 """Benchmarks for `tf.data.experimental.rejection_resample()`.""" 58 59 def benchmarkResamplePerformance(self): 60 init_dist = [0.25, 0.25, 0.25, 0.25] 61 target_dist = [0.0, 0.0, 0.0, 1.0] 62 num_classes = len(init_dist) 63 # We don't need many samples to test a dirac-delta target distribution 64 num_samples = 1000 65 data_np = np.random.choice(num_classes, num_samples, p=init_dist) 66 67 resample_time = _time_resampling( 68 data_np, target_dist, init_dist, num_to_sample=1000) 69 70 self.report_benchmark(iters=1000, wall_time=resample_time, name="resample") 71 72 73 if __name__ == "__main__": 74 test.main() 75