# `Predictor` classes provide an interface for efficient, repeated inference.

# Loads must precede every other statement (buildifier `load-on-top`); this
# load shadows the native `py_test` rule used by the test targets below.
load("//tensorflow:tensorflow.bzl", "py_test")

package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Package entry point (__init__.py); publicly visible and built on top of the
# factory functions in :predictor_factories.
py_library(
    name = "predictor",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":predictor_factories",
        "//tensorflow/python:util",
    ],
)

# Factory functions that construct the concrete predictor implementations
# declared below.
py_library(
    name = "predictor_factories",
    srcs = ["predictor_factories.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":contrib_estimator_predictor",
        ":core_estimator_predictor",
        ":saved_model_predictor",
        "//tensorflow/python/estimator:estimator_py",
    ],
)

# Shared base module (predictor.py) that every concrete predictor depends on.
py_library(
    name = "base_predictor",
    srcs = ["predictor.py"],
    srcs_version = "PY2AND3",
    deps = ["@six_archive//:six"],
)

# Predictor backed by a SavedModel loaded via the saved_model loader.
py_library(
    name = "saved_model_predictor",
    srcs = ["saved_model_predictor.py"],
    srcs_version = "PY2AND3",
    # Extra visibility beyond the package default for an external TPU package.
    visibility = ["//learning/brain/contrib/learn/tpu:__subpackages__"],
    deps = [
        ":base_predictor",
        "//tensorflow/contrib/saved_model:saved_model_py",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:session",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:signature_constants",
    ],
)

# Predictor built from a core `tf.estimator` Estimator.
py_library(
    name = "core_estimator_predictor",
    srcs = ["core_estimator_predictor.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base_predictor",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/saved_model:signature_constants",
    ],
)

# Predictor built from a `tf.contrib.learn` Estimator.
py_library(
    name = "contrib_estimator_predictor",
    srcs = ["contrib_estimator_predictor.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base_predictor",
        "//tensorflow/contrib/learn",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
    ],
)

# Helpers shared by the tests in this package; tagged no_pip like the tests.
py_library(
    name = "testing_common",
    srcs = ["testing_common.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        "//tensorflow/contrib/learn",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/saved_model:signature_constants",
    ],
)

# Transitive dependencies of this target will be included in the pip package.
py_library(
    name = "predictor_pip",
    visibility = ["//visibility:public"],
    deps = [
        ":contrib_estimator_predictor",
        ":core_estimator_predictor",
        ":saved_model_predictor",
    ],
)

py_test(
    name = "saved_model_predictor_test",
    srcs = ["saved_model_predictor_test.py"],
    data = [":test_export_dir"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":saved_model_predictor",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python/saved_model:signature_def_utils",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "predictor_factories_test",
    srcs = ["predictor_factories_test.py"],
    data = [":test_export_dir"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":predictor_factories",
        ":testing_common",
        # Declared explicitly, matching every other py_test in this package.
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "core_estimator_predictor_test",
    srcs = ["core_estimator_predictor_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":core_estimator_predictor",
        ":testing_common",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "contrib_estimator_predictor_test",
    srcs = ["contrib_estimator_predictor_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":contrib_estimator_predictor",
        ":testing_common",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)

# Every file under test_export_dir/, supplied as runtime `data` to the
# saved_model_predictor and predictor_factories tests.
filegroup(
    name = "test_export_dir",
    srcs = glob(["test_export_dir/**/*"]),
    tags = ["no_pip"],
)