Home | History | Annotate | Download | only in predictor
      1 # `Predictor` classes provide an interface for efficient, repeated inference.
      2 
      3 package(default_visibility = ["//tensorflow/contrib/predictor:__subpackages__"])
      4 
      5 licenses(["notice"])  # Apache 2.0
      6 
      7 exports_files(["LICENSE"])
      8 
      9 load("//tensorflow:tensorflow.bzl", "py_test")
     10 
# Publicly visible library for the package's `__init__.py`; pulls in
# `:predictor_factories`.
# NOTE(review): `//tensorflow/python:util` is presumably used by `__init__.py`
# for API-surface helpers — confirm against `__init__.py`.
py_library(
    name = "predictor",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":predictor_factories",
        "//tensorflow/python:util",
    ],
)
     21 
# Library for predictor_factories.py. It depends on all three concrete
# predictor libraries (contrib estimator, core estimator, SavedModel), so it
# presumably constructs whichever one a caller asks for — confirm against
# predictor_factories.py.
py_library(
    name = "predictor_factories",
    srcs = ["predictor_factories.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":contrib_estimator_predictor",
        ":core_estimator_predictor",
        ":saved_model_predictor",
        "//tensorflow/python/estimator:estimator_py",
    ],
)
     33 
# Library for predictor.py. Each concrete predictor library in this package
# depends on it. Its only external dependency is `six` (Python 2/3
# compatibility).
py_library(
    name = "base_predictor",
    srcs = ["predictor.py"],
    srcs_version = "PY2AND3",
    deps = ["@six_archive//:six"],
)
     40 
# Library for saved_model_predictor.py. Built on `:base_predictor`, with deps
# on the SavedModel loader/signature constants and session/graph ops.
# This explicit `visibility` REPLACES the package `default_visibility`, so
# subpackages of //tensorflow/contrib/predictor cannot depend on this target
# (targets in this same package always can).
# NOTE(review): the `//learning/brain/...` package is not part of the
# open-source tree; presumably an internal consumer — confirm it is still
# needed.
py_library(
    name = "saved_model_predictor",
    srcs = ["saved_model_predictor.py"],
    srcs_version = "PY2AND3",
    visibility = ["//learning/brain/contrib/learn/tpu:__subpackages__"],
    deps = [
        ":base_predictor",
        "//tensorflow/contrib/saved_model:saved_model_py",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:session",
        "//tensorflow/python/saved_model:loader",
        "//tensorflow/python/saved_model:signature_constants",
    ],
)
     55 
# Library for core_estimator_predictor.py. Judging by its deps, this is the
# predictor for core (//tensorflow/python/estimator) estimators — confirm
# against the source file.
py_library(
    name = "core_estimator_predictor",
    srcs = ["core_estimator_predictor.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base_predictor",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/saved_model:signature_constants",
    ],
)
     68 
# Library for contrib_estimator_predictor.py. Judging by its deps, this is the
# predictor for tf.contrib.learn estimators — confirm against the source file.
py_library(
    name = "contrib_estimator_predictor",
    srcs = ["contrib_estimator_predictor.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base_predictor",
        "//tensorflow/contrib/learn",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
    ],
)
     80 
     81 py_library(
     82     name = "testing_common",
     83     srcs = ["testing_common.py"],
     84     srcs_version = "PY2AND3",
     85     tags = ["no_pip"],
     86     deps = [
     87         "//tensorflow/contrib/learn",
     88         "//tensorflow/python:array_ops",
     89         "//tensorflow/python:constant_op",
     90         "//tensorflow/python:control_flow_ops",
     91         "//tensorflow/python:framework_ops",
     92         "//tensorflow/python:math_ops",
     93         "//tensorflow/python/estimator:estimator_py",
     94         "//tensorflow/python/saved_model:signature_constants",
     95     ],
     96 )
     97 
# Transitive dependencies of this target will be included in the pip package.
# Aggregate-only target: it has no srcs of its own and only collects the three
# concrete predictor libraries.
# NOTE(review): `:predictor` (`__init__.py`) and `:predictor_factories` are not
# listed here. Unless the pip build reaches them another way, those files
# would not be packaged — confirm.
py_library(
    name = "predictor_pip",
    visibility = ["//visibility:public"],
    deps = [
        ":contrib_estimator_predictor",
        ":core_estimator_predictor",
        ":saved_model_predictor",
    ],
)
    108 
# Test for `:saved_model_predictor`; receives the checked-in
# `:test_export_dir` files as runtime data.
# `no_pip` is a TensorFlow tag, presumably excluding this target from
# pip-package test runs.
py_test(
    name = "saved_model_predictor_test",
    srcs = ["saved_model_predictor_test.py"],
    data = [":test_export_dir"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":saved_model_predictor",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python/saved_model:signature_def_utils",
        "//third_party/py/numpy",
    ],
)
    124 
    125 py_test(
    126     name = "predictor_factories_test",
    127     srcs = ["predictor_factories_test.py"],
    128     data = [":test_export_dir"],
    129     srcs_version = "PY2AND3",
    130     tags = ["no_pip"],
    131     deps = [
    132         ":predictor_factories",
    133         ":testing_common",
    134     ],
    135 )
    136 
# Test for `:core_estimator_predictor`, using shared helpers from
# `:testing_common`. Unlike the SavedModel/factory tests, it takes no `data`.
py_test(
    name = "core_estimator_predictor_test",
    srcs = ["core_estimator_predictor_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":core_estimator_predictor",
        ":testing_common",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)
    149 
# Test for `:contrib_estimator_predictor`, using shared helpers from
# `:testing_common`. Unlike the SavedModel/factory tests, it takes no `data`.
py_test(
    name = "contrib_estimator_predictor_test",
    srcs = ["contrib_estimator_predictor_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        ":contrib_estimator_predictor",
        ":testing_common",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)
    162 
# Checked-in test fixture: every file under test_export_dir/ (recursive glob).
# Consumed as `data` by saved_model_predictor_test and
# predictor_factories_test. It is presumably a SavedModel export — confirm.
filegroup(
    name = "test_export_dir",
    srcs = glob(["test_export_dir/**/*"]),
    tags = ["no_pip"],
)
    168