1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests and benchmarks for creating RPC clusters on localhost.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import time 22 23 import numpy as np 24 25 from tensorflow.python.client import session as session_lib 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.framework import ops 28 from tensorflow.python.ops import partitioned_variables 29 from tensorflow.python.ops import variable_scope 30 from tensorflow.python.ops import variables 31 from tensorflow.python.platform import test 32 from tensorflow.python.training import device_setter 33 34 35 class CreateLocalClusterTest(test.TestCase): 36 37 def testCreateLocalCluster(self): 38 workers, _ = test.create_local_cluster(num_workers=2, num_ps=2) 39 worker_sessions = [session_lib.Session(w.target) for w in workers] 40 with ops.device("/job:ps/task:0"): 41 var0 = variables.Variable(0.0) 42 with ops.device("/job:ps/task:1"): 43 var1 = variables.Variable(1.0) 44 worker_sessions[0].run([var0.initializer, var1.initializer]) 45 with ops.device("/job:ps/task:0"): 46 var2 = variables.Variable(2.0) 47 with ops.device("/job:ps/task:1"): 48 var3 = variables.Variable(3.0) 49 worker_sessions[1].run([var2.initializer, var3.initializer]) 50 51 # Read values back in the opposite session 52 self.assertAllEqual(0.0, var0.eval(session=worker_sessions[1])) 53 self.assertAllEqual(1.0, var1.eval(session=worker_sessions[1])) 54 self.assertAllEqual(2.0, var2.eval(session=worker_sessions[0])) 55 self.assertAllEqual(3.0, var3.eval(session=worker_sessions[0])) 56 57 58 class CreateLocalClusterBenchmark(test.Benchmark): 59 60 def benchmarkCreateLocalCluster(self): 61 deltas = [] 62 iters = 5 63 for _ in range(iters): 64 start_time = time.time() 65 test.create_local_cluster(num_workers=1, num_ps=10) 66 end_time = time.time() 67 deltas.append(end_time - start_time) 68 69 median_deltas = np.median(deltas) 70 print("\n\nbenchmark_create_local_cluster_1_worker_10_ps. " 71 "iterations: %d, median wall time: %g\n\n" % (iters, median_deltas)) 72 self.report_benchmark( 73 iters=iters, 74 wall_time=median_deltas, 75 name="benchmark_create_local_cluster_1_worker_10_ps") 76 77 78 class PartitionedVariablesBenchmark(test.Benchmark): 79 80 def benchmark_create_1000_partitions_with_100_parameter_servers(self): 81 workers, _ = test.create_local_cluster(num_workers=1, num_ps=100) 82 worker_sessions = [session_lib.Session(w.target) for w in workers] 83 worker = worker_sessions[0] 84 partition_sizes = (1, 512, 1024 * 32, 1024 * 128) 85 86 partitioned = [] 87 88 for partition_size in partition_sizes: 89 # max_shard_bytes is 4, shape is 1000*partition_size float32s which should 90 # partition into 1000 shards, each containing partition_size float32s. 91 print("Building partitioned variable with %d floats per partition" % 92 partition_size) 93 with ops.device(device_setter.replica_device_setter(ps_tasks=100)): 94 partitioned_ix = variable_scope.get_variable( 95 "partitioned_%d" % partition_size, 96 shape=[1000 * partition_size], 97 dtype=dtypes.float32, 98 # Each partition to have exactly N float32s 99 partitioner=partitioned_variables.variable_axis_size_partitioner( 100 max_shard_bytes=4 * partition_size)) 101 # Concatenates along axis 0 102 partitioned.append(ops.convert_to_tensor(partitioned_ix)) 103 104 variables.global_variables_initializer().run(session=worker) 105 106 for ix, partition_size in enumerate(partition_sizes): 107 print("Running benchmark having partitions with %d floats" % 108 partition_size) 109 self.run_op_benchmark( 110 worker, 111 partitioned[ix], 112 name=("read_concat_1000_partitions_from_" 113 "100_parameter_servers_partsize_%d_floats" % partition_size)) 114 115 116 if __name__ == "__main__": 117 test.main() 118