Home | History | Annotate | Download | only in tensorrt
# Description:
#   Wraps NVIDIA TensorRT (https://developer.nvidia.com/tensorrt) for TensorFlow
#   and provides TensorRT operators and a converter package.
#   These APIs are expected to change over time.
      5 
      6 package(default_visibility = ["//tensorflow:__subpackages__"])
      7 
      8 licenses(["notice"])  # Apache 2.0
      9 
     10 exports_files(["LICENSE"])
     11 
     12 load(
     13     "//tensorflow:tensorflow.bzl",
     14     "tf_cc_test",
     15     "tf_copts",
     16     "tf_cuda_library",
     17     "tf_custom_op_library",
     18     "tf_custom_op_library_additional_deps",
     19     "tf_gen_op_libs",
     20     "tf_gen_op_wrapper_py",
     21 )
     22 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
     23 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
     24 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
     25 load(
     26     "@local_config_tensorrt//:build_defs.bzl",
     27     "if_tensorrt",
     28 )
     29 
     30 tf_cuda_cc_test(
     31     name = "tensorrt_test_cc",
     32     size = "small",
     33     srcs = ["tensorrt_test.cc"],
     34     tags = [
     35         "manual",
     36         "notap",
     37     ],
     38     deps = [
     39         "//tensorflow/core:lib",
     40         "//tensorflow/core:test",
     41         "//tensorflow/core:test_main",
     42     ] + if_tensorrt([
     43         "@local_config_cuda//cuda:cuda_headers",
     44         "@local_config_tensorrt//:nv_infer",
     45     ]),
     46 )
     47 
     48 tf_custom_op_library(
     49     name = "python/ops/_trt_engine_op.so",
     50     srcs = ["ops/trt_engine_op.cc"],
     51     deps = [
     52         ":trt_engine_op_kernel",
     53         ":trt_shape_function",
     54         "//tensorflow/core:lib_proto_parsing",
     55     ] + if_tensorrt([
     56         "@local_config_tensorrt//:nv_infer",
     57     ]),
     58 )
     59 
     60 tf_cuda_library(
     61     name = "trt_shape_function",
     62     srcs = ["shape_fn/trt_shfn.cc"],
     63     hdrs = ["shape_fn/trt_shfn.h"],
     64     visibility = ["//visibility:public"],
     65     deps = [
     66         ":trt_logging",
     67     ] + if_tensorrt([
     68         "@local_config_tensorrt//:nv_infer",
     69     ]) + tf_custom_op_library_additional_deps(),
     70 )
     71 
     72 cc_library(
     73     name = "trt_engine_op_kernel",
     74     srcs = ["kernels/trt_engine_op.cc"],
     75     hdrs = ["kernels/trt_engine_op.h"],
     76     copts = tf_copts(),
     77     deps = [
     78         ":trt_logging",
     79         "//tensorflow/core:gpu_headers_lib",
     80         "//tensorflow/core:lib_proto_parsing",
     81         "//tensorflow/core:stream_executor_headers_lib",
     82     ] + if_tensorrt([
     83         "@local_config_tensorrt//:nv_infer",
     84     ]) + tf_custom_op_library_additional_deps(),
     85     alwayslink = 1,
     86 )
     87 
     88 tf_gen_op_libs(
     89     op_lib_names = ["trt_engine_op"],
     90     deps = if_tensorrt([
     91         "@local_config_tensorrt//:nv_infer",
     92     ]),
     93 )
     94 
     95 tf_cuda_library(
     96     name = "trt_logging",
     97     srcs = ["log/trt_logger.cc"],
     98     hdrs = ["log/trt_logger.h"],
     99     visibility = ["//visibility:public"],
    100     deps = [
    101         "//tensorflow/core:lib_proto_parsing",
    102     ] + if_tensorrt([
    103         "@local_config_tensorrt//:nv_infer",
    104     ]),
    105 )
    106 
    107 tf_gen_op_wrapper_py(
    108     name = "trt_engine_op",
    109     deps = [
    110         ":trt_engine_op_op_lib",
    111         ":trt_logging",
    112         ":trt_shape_function",
    113     ],
    114 )
    115 
    116 tf_custom_op_py_library(
    117     name = "trt_engine_op_loader",
    118     srcs = ["python/ops/trt_engine_op.py"],
    119     dso = [
    120         ":python/ops/_trt_engine_op.so",
    121     ] + if_tensorrt([
    122         "@local_config_tensorrt//:nv_infer",
    123     ]),
    124     srcs_version = "PY2AND3",
    125     deps = [
    126         "//tensorflow/python:framework_for_generated_wrappers",
    127         "//tensorflow/python:resources",
    128     ],
    129 )
    130 
    131 py_library(
    132     name = "init_py",
    133     srcs = [
    134         "__init__.py",
    135         "python/__init__.py",
    136     ],
    137     srcs_version = "PY2AND3",
    138     deps = [
    139         ":trt_convert_py",
    140         ":trt_ops_py",
    141     ],
    142 )
    143 
    144 py_library(
    145     name = "trt_ops_py",
    146     srcs_version = "PY2AND3",
    147     deps = [
    148         ":trt_engine_op",
    149         ":trt_engine_op_loader",
    150     ],
    151 )
    152 
    153 py_library(
    154     name = "trt_convert_py",
    155     srcs = ["python/trt_convert.py"],
    156     srcs_version = "PY2AND3",
    157     deps = [
    158         ":wrap_conversion",
    159     ],
    160 )
    161 
    162 tf_py_wrap_cc(
    163     name = "wrap_conversion",
    164     srcs = ["trt_conversion.i"],
    165     copts = tf_copts(),
    166     deps = [
    167         ":trt_conversion",
    168         "//tensorflow/core:framework_lite",
    169         "//util/python:python_headers",
    170     ],
    171 )
    172 
    173 # Library for the node-level conversion portion of TensorRT operation creation
    174 tf_cuda_library(
    175     name = "trt_conversion",
    176     srcs = [
    177         "convert/convert_graph.cc",
    178         "convert/convert_nodes.cc",
    179     ],
    180     hdrs = [
    181         "convert/convert_graph.h",
    182         "convert/convert_nodes.h",
    183     ],
    184     deps = [
    185         ":segment",
    186         ":trt_logging",
    187         "//tensorflow/core/grappler:grappler_item",
    188         "//tensorflow/core/grappler:utils",
    189         "//tensorflow/core:framework",
    190         "//tensorflow/core:framework_lite",
    191         "//tensorflow/core:graph",
    192         "//tensorflow/core:lib",
    193         "//tensorflow/core:lib_internal",
    194         "//tensorflow/core:protos_all_cc",
    195         "//tensorflow/core/grappler:devices",
    196         "//tensorflow/core/grappler/clusters:virtual_cluster",
    197         "//tensorflow/core/grappler/costs:graph_properties",
    198         "//tensorflow/core/grappler/optimizers:constant_folding",
    199         "//tensorflow/core/grappler/optimizers:layout_optimizer",
    200     ] + if_tensorrt([
    201         "@local_config_tensorrt//:nv_infer",
    202     ]) + tf_custom_op_library_additional_deps(),
    203 )
    204 
    205 # Library for the segmenting portion of TensorRT operation creation
    206 cc_library(
    207     name = "segment",
    208     srcs = ["segment/segment.cc"],
    209     hdrs = [
    210         "segment/segment.h",
    211         "segment/union_find.h",
    212     ],
    213     linkstatic = 1,
    214     deps = [
    215         "//tensorflow/core:graph",
    216         "//tensorflow/core:lib_proto_parsing",
    217         "//tensorflow/core:protos_all_cc",
    218         "@protobuf_archive//:protobuf_headers",
    219     ],
    220 )
    221 
    222 tf_cc_test(
    223     name = "segment_test",
    224     size = "small",
    225     srcs = ["segment/segment_test.cc"],
    226     deps = [
    227         ":segment",
    228         "//tensorflow/c:c_api",
    229         "//tensorflow/core:lib",
    230         "//tensorflow/core:protos_all_cc",
    231         "//tensorflow/core:test",
    232         "//tensorflow/core:test_main",
    233     ],
    234 )
    235 
    236 filegroup(
    237     name = "all_files",
    238     srcs = glob(
    239         ["**/*"],
    240         exclude = [
    241             "**/METADATA",
    242             "**/OWNERS",
    243         ],
    244     ),
    245     visibility = ["//tensorflow:__subpackages__"],
    246 )
    247