# Description:
#   Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with TensorFlow
#   and provide TensorRT operators and a converter package.
#   The APIs are expected to change over time.

# Bazel requires every load() to precede all other statements in a BUILD
# file, so all macros used below are loaded here, once per .bzl file.
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
    "tf_copts",
    "tf_cuda_cc_test",
    "tf_cuda_library",
    "tf_custom_op_library",
    "tf_custom_op_library_additional_deps",
    "tf_custom_op_py_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
    "tf_py_wrap_cc",
)
load(
    "@local_config_tensorrt//:build_defs.bzl",
    "if_tensorrt",
)

package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# C++ test compiled against the CUDA headers and TensorRT's nv_infer library
# (when TensorRT is configured). Tagged "manual", so wildcard target patterns
# such as //... skip it; it must be requested explicitly.
tf_cuda_cc_test(
    name = "tensorrt_test_cc",
    size = "small",
    srcs = ["tensorrt_test.cc"],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ] + if_tensorrt([
        "@local_config_cuda//cuda:cuda_headers",
        "@local_config_tensorrt//:nv_infer",
    ]),
)

# Custom-op shared object built from ops/trt_engine_op.cc together with the
# engine-op kernel and shape function; :trt_engine_op_loader ships it as its
# dso.
tf_custom_op_library(
    name = "python/ops/_trt_engine_op.so",
    srcs = ["ops/trt_engine_op.cc"],
    deps = [
        ":trt_engine_op_kernel",
        ":trt_shape_function",
        "//tensorflow/core:lib_proto_parsing",
    ] + if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]),
)

# Shape function for the engine op (shape_fn/trt_shfn.*).
tf_cuda_library(
    name = "trt_shape_function",
    srcs = ["shape_fn/trt_shfn.cc"],
    hdrs = ["shape_fn/trt_shfn.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":trt_logging",
    ] + if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]) + tf_custom_op_library_additional_deps(),
)

# Engine-op kernel. alwayslink forces every object file in srcs into any
# binary that depends on this library, even if none of its symbols are
# referenced directly (presumably so the kernel registration is not dropped).
cc_library(
    name = "trt_engine_op_kernel",
    srcs = ["kernels/trt_engine_op.cc"],
    hdrs = ["kernels/trt_engine_op.h"],
    copts = tf_copts(),
    deps = [
        ":trt_logging",
        "//tensorflow/core:gpu_headers_lib",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:stream_executor_headers_lib",
    ] + if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]) + tf_custom_op_library_additional_deps(),
    alwayslink = 1,
)

# Op library for trt_engine_op; provides the :trt_engine_op_op_lib target
# consumed by the generated Python wrapper below.
tf_gen_op_libs(
    op_lib_names = ["trt_engine_op"],
    deps = if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]),
)

# TensorRT logger (log/trt_logger.*) shared by the targets in this package.
tf_cuda_library(
    name = "trt_logging",
    srcs = ["log/trt_logger.cc"],
    hdrs = ["log/trt_logger.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:lib_proto_parsing",
    ] + if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]),
)

# Generated Python wrappers for the trt_engine_op op.
tf_gen_op_wrapper_py(
    name = "trt_engine_op",
    deps = [
        ":trt_engine_op_op_lib",
        ":trt_logging",
        ":trt_shape_function",
    ],
)

# Python op module (python/ops/trt_engine_op.py) packaged together with the
# custom-op shared object.
tf_custom_op_py_library(
    name = "trt_engine_op_loader",
    srcs = ["python/ops/trt_engine_op.py"],
    dso = [
        ":python/ops/_trt_engine_op.so",
    ] + if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]),
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:resources",
    ],
)

# Package __init__ modules; pulls in both the converter and the ops.
py_library(
    name = "init_py",
    srcs = [
        "__init__.py",
        "python/__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":trt_convert_py",
        ":trt_ops_py",
    ],
)

# Source-less aggregate of the generated op wrappers and the op loader.
py_library(
    name = "trt_ops_py",
    srcs_version = "PY2AND3",
    deps = [
        ":trt_engine_op",
        ":trt_engine_op_loader",
    ],
)

# Python converter module (python/trt_convert.py) built on :wrap_conversion.
py_library(
    name = "trt_convert_py",
    srcs = ["python/trt_convert.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":wrap_conversion",
    ],
)

# SWIG interface (trt_conversion.i) exposing :trt_conversion to Python.
tf_py_wrap_cc(
    name = "wrap_conversion",
    srcs = ["trt_conversion.i"],
    copts = tf_copts(),
    deps = [
        ":trt_conversion",
        "//tensorflow/core:framework_lite",
        "//util/python:python_headers",
    ],
)
# Library for the graph- and node-level conversion portions of TensorRT
# operation creation (convert/convert_graph.* and convert/convert_nodes.*).
tf_cuda_library(
    name = "trt_conversion",
    srcs = [
        "convert/convert_graph.cc",
        "convert/convert_nodes.cc",
    ],
    hdrs = [
        "convert/convert_graph.h",
        "convert/convert_nodes.h",
    ],
    # Kept in buildifier's canonical order (local labels, then by package,
    # then by target) like every other deps list in this file.
    deps = [
        ":segment",
        ":trt_logging",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:devices",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/clusters:virtual_cluster",
        "//tensorflow/core/grappler/costs:graph_properties",
        "//tensorflow/core/grappler/optimizers:constant_folding",
        "//tensorflow/core/grappler/optimizers:layout_optimizer",
    ] + if_tensorrt([
        "@local_config_tensorrt//:nv_infer",
    ]) + tf_custom_op_library_additional_deps(),
)

# Library for the segmenting portion of TensorRT operation creation.
cc_library(
    name = "segment",
    srcs = ["segment/segment.cc"],
    hdrs = [
        "segment/segment.h",
        "segment/union_find.h",
    ],
    linkstatic = 1,
    deps = [
        "//tensorflow/core:graph",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "@protobuf_archive//:protobuf_headers",
    ],
)

# Unit test for :segment.
tf_cc_test(
    name = "segment_test",
    size = "small",
    srcs = ["segment/segment_test.cc"],
    deps = [
        ":segment",
        "//tensorflow/c:c_api",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Every file in this package except METADATA and OWNERS files.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)