# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Benchmark for split and grad of split."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import random
import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def build_graph(device, input_shape, variable, num_inputs, axis, grad):
  """Build a graph containing a sequence of concat operations.

  Args:
    device: string, the device to run on.
    input_shape: shape of the input tensors.
    variable: whether or not to randomize the input shapes.
    num_inputs: the number of inputs to concat.
    axis: axis along which to concatenate.
    grad: if True, also build the gradients of the concat outputs.

  Returns:
    An op that, when run, executes all of the concat (and, if requested,
    gradient) operations.
  """
  with ops.device("/%s:0" % device):
    if not variable:
      inputs = [array_ops.zeros(input_shape) for _ in range(num_inputs)]
    else:
      # Concat inputs only need to agree on the non-concat dimensions, so
      # randomize each input's size along the concat axis by up to +/-5
      # (clamped below at 1).
      if axis == 1:
        inputs = [
            array_ops.zeros([
                input_shape[0],
                random.randint(max(1, input_shape[1] - 5), input_shape[1] + 5)
            ]) for _ in range(num_inputs)
        ]
      else:
        inputs = [
            array_ops.zeros([
                random.randint(max(1, input_shape[0] - 5), input_shape[0] + 5),
                input_shape[1]
            ]) for _ in range(num_inputs)
        ]

    # Build 100 concat ops per step so that a single session.run() amortizes
    # launch overhead across many ops.
    outputs = [array_ops.concat(inputs, axis) for _ in range(100)]
    if grad:
      # Flatten the per-output gradient lists and group them into a single
      # op, so running it executes every backward concat.
      return control_flow_ops.group(*list(
          itertools.chain.from_iterable([
              gradients_impl.gradients(output, inputs) for output in outputs
          ])))
    else:
      return control_flow_ops.group(*outputs)

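# A minimal usage sketch of build_graph (illustrative only, not part of the
# benchmark itself):
#
#   g = ops.Graph()
#   with g.as_default():
#     run_op = build_graph("cpu", input_shape=[100, 97], variable=False,
#                          num_inputs=4, axis=0, grad=False)
#   with session_lib.Session(graph=g) as sess:
#     sess.run(run_op)  # executes the 100 grouped concat ops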


class ConcatBenchmark(test.Benchmark):
  """Benchmark concat."""

  def _run_graph(self, device, input_shape, variable, num_inputs, axis, grad,
                 num_iters):
    """Run the graph and print its execution time.

    Args:
      device: string, the device to run on.
      input_shape: shape of the input tensors.
      variable: whether or not to randomize the input shapes.
      num_inputs: the number of inputs to concat.
      axis: axis along which to concatenate.
      grad: if True, also run the gradients of the concat outputs.
      num_iters: number of steps to run.

    Returns:
      The total duration of the timed iterations, in seconds.
    """
    graph = ops.Graph()
    with graph.as_default():
      outputs = build_graph(device, input_shape, variable, num_inputs, axis,
                            grad)
    # Disable graph optimizations (opt_level L0) so the timing reflects the
    # raw cost of the concat ops rather than an optimized graph.
    config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
        optimizer_options=config_pb2.OptimizerOptions(
            opt_level=config_pb2.OptimizerOptions.L0)))
    with session_lib.Session(graph=graph, config=config) as session:
      variables.global_variables_initializer().run()
      _ = session.run(outputs)  # warm up.
      start_time = time.time()
      for _ in range(num_iters):
        _ = session.run(outputs)
      duration = time.time() - start_time
      # Throughput assumes float32 elements (4 bytes), one read plus one
      # write per element, and the 100 concats built per step.
      print("%s shape:%d/%d var: %r #inputs:%d axis:%d grad:%r - %f secs - %f "
            "GB/sec" % (device, input_shape[0], input_shape[1], variable,
                        num_inputs, axis, grad, duration / num_iters,
                        num_inputs * input_shape[0] * input_shape[1] * 4 * 2 *
                        100 / (duration / num_iters) / 1e9))

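    # Worked example of the throughput formula above (numbers illustrative,
    # not measured): for shape [2000, 8] with 20 inputs, bytes moved per
    # iteration = 20 * 2000 * 8 * 4 * 2 * 100 = 2.56e8, so a per-iteration
    # time of 0.1 seconds would report 2.56e8 / 0.1 / 1e9 = 2.56 GB/sec.
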
    name_template = (
        "concat_bench_{device}_input_shape_{shape}_variable_{variable}"
        "_num_inputs_{num_inputs}_axis_{axis}_grad_{grad}")

    # Pass iters to report_benchmark directly; the name template has no
    # {iters} placeholder, so passing it to format() would be a no-op.
    self.report_benchmark(
        name=name_template.format(
            device=device,
            num_inputs=num_inputs,
            variable=variable,
            grad=grad,
            shape=str(input_shape).replace(" ", ""),
            axis=str(axis)),
        iters=num_iters)

    return duration

  def benchmark_concat(self):
    print("Forward vs backward concat")
    shapes = [[2000, 8], [8, 2000], [100, 18], [1000, 18], [100, 97],
              [1000, 97], [10000, 1], [1, 10000]]
    axis_ = [0, 1]
    num_inputs = 20
    num_iters = [10] * len(shapes)
    variable = [False, True]  # fixed input size or not
    for shape, iters in zip(shapes, num_iters):
      for axis in axis_:
        for v in variable:
          self._run_graph("cpu", shape, v, num_inputs, axis, True, iters)


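# A sketch of how to invoke the benchmarks, assuming the standard tf.test
# benchmark runner: binaries built on test.main() only execute Benchmark
# subclasses when passed the --benchmarks flag (a regex over benchmark
# names), e.g.:
#   python concat_benchmark.py --benchmarks=ConcatBenchmark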
if __name__ == "__main__":
  test.main()