Home | History | Annotate | Download | only in tools
      1 # Description:
      2 #   Tools for manipulating TensorFlow graphs.
      3 
      4 package(default_visibility = ["//visibility:public"])
      5 
      6 licenses(["notice"])  # Apache 2.0
      7 
      8 exports_files(["LICENSE"])
      9 
     10 load("//tensorflow:tensorflow.bzl", "py_test")
     11 
     12 # Transitive dependencies of this target will be included in the pip package.
     13 py_library(
     14     name = "tools_pip",
     15     deps = [
     16         ":freeze_graph",
     17         ":inspect_checkpoint",
     18         ":optimize_for_inference",
     19         ":print_selective_registration_header",
     20         ":saved_model_cli",
     21         ":saved_model_utils",
     22         ":strip_unused",
     23     ],
     24 )
     25 
     26 py_library(
     27     name = "saved_model_utils",
     28     srcs = ["saved_model_utils.py"],
     29     srcs_version = "PY2AND3",
     30     deps = ["//tensorflow:tensorflow_py"],
     31 )
     32 
     33 py_library(
     34     name = "freeze_graph_lib",
     35     srcs = ["freeze_graph.py"],
     36     srcs_version = "PY2AND3",
     37     deps = [
     38         ":saved_model_utils",
     39         "//tensorflow/core:protos_all_py",
     40         "//tensorflow/python",  # TODO(b/34059704): remove when fixed
     41         "//tensorflow/python:client",
     42         "//tensorflow/python:framework",
     43         "//tensorflow/python:platform",
     44         "//tensorflow/python:training",
     45         "@six_archive//:six",
     46     ],
     47 )
     48 
     49 py_binary(
     50     name = "freeze_graph",
     51     srcs = ["freeze_graph.py"],
     52     srcs_version = "PY2AND3",
     53     deps = [
     54         ":saved_model_utils",
     55         "//tensorflow/core:protos_all_py",
     56         "//tensorflow/python",  # TODO(b/34059704): remove when fixed
     57         "//tensorflow/python:client",
     58         "//tensorflow/python:framework",
     59         "//tensorflow/python:platform",
     60         "//tensorflow/python:training",
     61         "@six_archive//:six",
     62     ],
     63 )
     64 
     65 py_binary(
     66     name = "import_pb_to_tensorboard",
     67     srcs = ["import_pb_to_tensorboard.py"],
     68     srcs_version = "PY2AND3",
     69     deps = [
     70         "//tensorflow/core:protos_all_py",
     71         "//tensorflow/python:client",
     72         "//tensorflow/python:framework",
     73         "//tensorflow/python:framework_ops",
     74         "//tensorflow/python:platform",
     75         "//tensorflow/python:summary",
     76     ],
     77 )
     78 
     79 py_test(
     80     name = "freeze_graph_test",
     81     size = "small",
     82     srcs = ["freeze_graph_test.py"],
     83     srcs_version = "PY2AND3",
     84     deps = [
     85         ":freeze_graph",
     86         "//tensorflow/core:protos_all_py",
     87         "//tensorflow/python:client",
     88         "//tensorflow/python:client_testlib",
     89         "//tensorflow/python:framework",
     90         "//tensorflow/python:framework_for_generated_wrappers",
     91         "//tensorflow/python:framework_test_lib",
     92         "//tensorflow/python:math_ops",
     93         "//tensorflow/python:training",
     94         "//tensorflow/python:variables",
     95     ],
     96 )
     97 
     98 py_binary(
     99     name = "inspect_checkpoint",
    100     srcs = ["inspect_checkpoint.py"],
    101     srcs_version = "PY2AND3",
    102     deps = [
    103         "//tensorflow/python",  # TODO(b/34059704): remove when fixed
    104         "//tensorflow/python:platform",
    105         "//tensorflow/python:pywrap_tensorflow",
    106     ],
    107 )
    108 
    109 py_library(
    110     name = "strip_unused_lib",
    111     srcs = ["strip_unused_lib.py"],
    112     srcs_version = "PY2AND3",
    113     deps = [
    114         "//tensorflow/core:protos_all_py",
    115         "//tensorflow/python:framework",
    116         "//tensorflow/python:platform",
    117     ],
    118 )
    119 
    120 py_binary(
    121     name = "strip_unused",
    122     srcs = ["strip_unused.py"],
    123     srcs_version = "PY2AND3",
    124     deps = [
    125         ":strip_unused_lib",
    126         "//tensorflow/python:framework_for_generated_wrappers",
    127         "//tensorflow/python:platform",
    128         "@six_archive//:six",
    129     ],
    130 )
    131 
    132 py_test(
    133     name = "strip_unused_test",
    134     size = "small",
    135     srcs = ["strip_unused_test.py"],
    136     srcs_version = "PY2AND3",
    137     deps = [
    138         ":strip_unused_lib",
    139         "//tensorflow/core:protos_all_py",
    140         "//tensorflow/python:client",
    141         "//tensorflow/python:client_testlib",
    142         "//tensorflow/python:framework",
    143         "//tensorflow/python:framework_for_generated_wrappers",
    144         "//tensorflow/python:framework_test_lib",
    145         "//tensorflow/python:math_ops",
    146     ],
    147 )
    148 
    149 py_library(
    150     name = "optimize_for_inference_lib",
    151     srcs = ["optimize_for_inference_lib.py"],
    152     srcs_version = "PY2AND3",
    153     deps = [
    154         ":strip_unused_lib",
    155         "//tensorflow/core:protos_all_py",
    156         "//tensorflow/python:framework",
    157         "//tensorflow/python:framework_for_generated_wrappers",
    158         "//tensorflow/python:platform",
    159         "//third_party/py/numpy",
    160         "@six_archive//:six",
    161     ],
    162 )
    163 
    164 py_binary(
    165     name = "optimize_for_inference",
    166     srcs = ["optimize_for_inference.py"],
    167     srcs_version = "PY2AND3",
    168     deps = [
    169         ":optimize_for_inference_lib",
    170         "//tensorflow/core:protos_all_py",
    171         "//tensorflow/python",  # TODO(b/34059704): remove when fixed
    172         "//tensorflow/python:framework",
    173         "//tensorflow/python:framework_for_generated_wrappers",
    174         "//tensorflow/python:platform",
    175         "@six_archive//:six",
    176     ],
    177 )
    178 
    179 py_test(
    180     name = "optimize_for_inference_test",
    181     size = "small",
    182     srcs = ["optimize_for_inference_test.py"],
    183     srcs_version = "PY2AND3",
    184     deps = [
    185         ":optimize_for_inference_lib",
    186         "//tensorflow/core:protos_all_py",
    187         "//tensorflow/python:array_ops",
    188         "//tensorflow/python:client_testlib",
    189         "//tensorflow/python:framework",
    190         "//tensorflow/python:framework_for_generated_wrappers",
    191         "//tensorflow/python:image_ops",
    192         "//tensorflow/python:math_ops",
    193         "//tensorflow/python:nn_ops",
    194         "//tensorflow/python:nn_ops_gen",
    195         "//third_party/py/numpy",
    196     ],
    197 )
    198 
    199 py_library(
    200     name = "selective_registration_header_lib",
    201     srcs = ["selective_registration_header_lib.py"],
    202     srcs_version = "PY2AND3",
    203     visibility = ["//visibility:public"],
    204     deps = [
    205         "//tensorflow/python",  # TODO(b/34059704): remove when fixed
    206         "//tensorflow/python:platform",
    207     ],
    208 )
    209 
    210 py_binary(
    211     name = "print_selective_registration_header",
    212     srcs = ["print_selective_registration_header.py"],
    213     srcs_version = "PY2AND3",
    214     visibility = ["//visibility:public"],
    215     deps = [
    216         ":selective_registration_header_lib",
    217         "//tensorflow/python:platform",
    218     ],
    219 )
    220 
    221 py_test(
    222     name = "print_selective_registration_header_test",
    223     srcs = ["print_selective_registration_header_test.py"],
    224     srcs_version = "PY2AND3",
    225     deps = [
    226         ":selective_registration_header_lib",
    227         "//tensorflow/python:client_testlib",
    228         "//tensorflow/python:platform",
    229     ],
    230 )
    231 
    232 py_binary(
    233     name = "saved_model_cli",
    234     srcs = ["saved_model_cli.py"],
    235     srcs_version = "PY2AND3",
    236     deps = [
    237         ":saved_model_utils",
    238         "//tensorflow/contrib/saved_model:saved_model_py",
    239         "//tensorflow/python",
    240         "//tensorflow/python/debug:local_cli_wrapper",
    241     ],
    242 )
    243 
    244 py_test(
    245     name = "saved_model_cli_test",
    246     srcs = ["saved_model_cli_test.py"],
    247     data = [
    248         "//tensorflow/cc/saved_model:saved_model_half_plus_two",
    249     ],
    250     srcs_version = "PY2AND3",
    251     tags = ["manual"],
    252     deps = [
    253         ":saved_model_cli",
    254         "//tensorflow/core:protos_all_py",
    255     ],
    256 )
    257 
    258 filegroup(
    259     name = "all_files",
    260     srcs = glob(
    261         ["**/*"],
    262         exclude = [
    263             "**/METADATA",
    264             "**/OWNERS",
    265             "bin/**",
    266             "gen/**",
    267         ],
    268     ),
    269     visibility = ["//tensorflow:__subpackages__"],
    270 )
    271