Home | History | Annotate | Download | only in gan
      1 # Files for using TFGAN framework.
        # Unless a target sets its own visibility, it is visible to every package
        # under //tensorflow (the default_visibility declared below).
      2 package(default_visibility = ["//tensorflow:__subpackages__"])
      3 
      4 licenses(["notice"])  # Apache 2.0 (declared as a "notice"-type license)
      5 
      6 exports_files(["LICENSE"])  # Lets other packages reference this package's LICENSE file.
      7 
      8 load("//tensorflow:tensorflow.bzl", "py_test")  # py_test used below is the symbol loaded from tensorflow.bzl.
      9 
     10 py_library(
            # Umbrella target for the package: ships only the top-level
            # __init__.py and depends on each sub-library listed in deps
            # (estimator, eval, features, losses, namedtuples, train).
     11     name = "gan",
     12     srcs = [
     13         "__init__.py",
     14     ],
     15     srcs_version = "PY2AND3",
     16     deps = [
     17         ":estimator",
     18         ":eval",
     19         ":features",
     20         ":losses",
     21         ":namedtuples",
     22         ":train",
     23         "//tensorflow/python:util",
     24     ],
     25 )
     26 
     27 py_library(
            # Single-source library (python/namedtuples.py) with no declared deps.
     28     name = "namedtuples",
     29     srcs = ["python/namedtuples.py"],
     30     srcs_version = "PY2AND3",
     31 )
     32 
     33 py_library(
            # Library built from python/train.py. Within this package it depends
            # on :losses and :namedtuples; the remaining deps are TensorFlow
            # targets outside this package.
     34     name = "train",
     35     srcs = ["python/train.py"],
     36     srcs_version = "PY2AND3",
     37     deps = [
     38         ":losses",
     39         ":namedtuples",
     40         "//tensorflow/contrib/framework:framework_py",
     41         "//tensorflow/contrib/slim:learning",
     42         "//tensorflow/contrib/training:training_py",
     43         "//tensorflow/python:array_ops",
     44         "//tensorflow/python:check_ops",
     45         "//tensorflow/python:framework_ops",
     46         "//tensorflow/python:init_ops",
     47         "//tensorflow/python:training",
     48         "//tensorflow/python:variable_scope",
     49         "//tensorflow/python/ops/distributions",
     50         "//tensorflow/python/ops/losses",
     51     ],
     52 )
     53 
     54 py_test(
            # Test target built from python/train_test.py; depends on :train plus
            # the local :features and :namedtuples libraries.
     55     name = "train_test",
     56     srcs = ["python/train_test.py"],
     57     srcs_version = "PY2AND3",
            # NOTE(review): "notsan" presumably keeps this test out of
            # ThreadSanitizer runs -- confirm against the CI tag conventions.
     58     tags = ["notsan"],
     59     deps = [
     60         ":features",
     61         ":namedtuples",
     62         ":train",
     63         "//tensorflow/contrib/framework:framework_py",
     64         "//tensorflow/contrib/slim:learning",
     65         "//tensorflow/python:array_ops",
     66         "//tensorflow/python:client_testlib",
     67         "//tensorflow/python:constant_op",
     68         "//tensorflow/python:dtypes",
     69         "//tensorflow/python:framework_ops",
     70         "//tensorflow/python:random_ops",
     71         "//tensorflow/python:random_seed",
     72         "//tensorflow/python:training",
     73         "//tensorflow/python:variable_scope",
     74         "//tensorflow/python:variables",
     75         "//tensorflow/python/ops/distributions",
     76         "//third_party/py/numpy",
     77     ],
     78 )
     79 
     80 py_library(
            # Aggregating target for python/eval: its only source is that
            # directory's __init__.py, and it depends on the four eval
            # sub-libraries (:classifier_metrics, :eval_utils,
            # :sliced_wasserstein, :summaries).
     81     name = "eval",
     82     srcs = ["python/eval/__init__.py"],
     83     srcs_version = "PY2AND3",
     84     deps = [
     85         ":classifier_metrics",
     86         ":eval_utils",
     87         ":sliced_wasserstein",
     88         ":summaries",
     89         "//tensorflow/python:util",
     90     ],
     91 )
     92 
     93 py_library(
            # Aggregating target for python/estimator: ships that directory's
            # __init__.py and depends on :gan_estimator and :head.
     94     name = "estimator",
     95     srcs = ["python/estimator/__init__.py"],
     96     srcs_version = "PY2AND3",
     97     deps = [
     98         ":gan_estimator",
     99         ":head",
    100         "//tensorflow/python:util",
    101     ],
    102 )
    103 
    104 py_library(
            # Aggregating target for python/losses: ships that directory's
            # __init__.py and depends on :losses_impl and :tuple_losses.
    105     name = "losses",
    106     srcs = ["python/losses/__init__.py"],
    107     srcs_version = "PY2AND3",
    108     deps = [
    109         ":losses_impl",
    110         ":tuple_losses",
    111         "//tensorflow/python:util",
    112     ],
    113 )
    114 
    115 py_library(
            # Aggregating target for python/features: ships that directory's
            # __init__.py and depends on the four feature sub-libraries
            # (:clip_weights, :conditioning_utils, :random_tensor_pool,
            # :virtual_batchnorm).
    116     name = "features",
    117     srcs = ["python/features/__init__.py"],
    118     srcs_version = "PY2AND3",
    119     deps = [
    120         ":clip_weights",
    121         ":conditioning_utils",
    122         ":random_tensor_pool",
    123         ":virtual_batchnorm",
    124         "//tensorflow/python:util",
    125     ],
    126 )
    127 
    128 py_library(
            # Library built from python/losses/python/losses_impl.py. It has no
            # deps inside this package; all deps are TensorFlow targets or numpy.
    129     name = "losses_impl",
    130     srcs = ["python/losses/python/losses_impl.py"],
    131     srcs_version = "PY2AND3",
    132     deps = [
    133         "//tensorflow/contrib/framework:framework_py",
    134         "//tensorflow/python:array_ops",
    135         "//tensorflow/python:clip_ops",
    136         "//tensorflow/python:framework_ops",
    137         "//tensorflow/python:gradients",
    138         "//tensorflow/python:math_ops",
    139         "//tensorflow/python:random_ops",
    140         "//tensorflow/python:summary",
    141         "//tensorflow/python:tensor_util",
    142         "//tensorflow/python:variable_scope",
    143         "//tensorflow/python/ops/distributions",
    144         "//tensorflow/python/ops/losses",
    145         "//third_party/py/numpy",
    146     ],
    147 )
    148 
    149 py_test(
            # Test target built from python/losses/python/losses_impl_test.py;
            # :losses_impl is its only dep from this package.
    150     name = "losses_impl_test",
    151     srcs = ["python/losses/python/losses_impl_test.py"],
    152     srcs_version = "PY2AND3",
    153     deps = [
    154         ":losses_impl",
    155         "//tensorflow/python:array_ops",
    156         "//tensorflow/python:client_testlib",
    157         "//tensorflow/python:clip_ops",
    158         "//tensorflow/python:constant_op",
    159         "//tensorflow/python:dtypes",
    160         "//tensorflow/python:framework_ops",
    161         "//tensorflow/python:math_ops",
    162         "//tensorflow/python:random_ops",
    163         "//tensorflow/python:random_seed",
    164         "//tensorflow/python:variable_scope",
    165         "//tensorflow/python:variables",
    166         "//tensorflow/python/ops/distributions",
    167         "//tensorflow/python/ops/losses",
    168     ],
    169 )
    170 
    171 py_library(
            # Bundles three sources from python/losses/python (losses_wargs.py,
            # tuple_losses.py, tuple_losses_impl.py) into one library that depends
            # on :losses_impl and :namedtuples.
    172     name = "tuple_losses",
    173     srcs = [
    174         "python/losses/python/losses_wargs.py",
    175         "python/losses/python/tuple_losses.py",
    176         "python/losses/python/tuple_losses_impl.py",
    177     ],
    178     srcs_version = "PY2AND3",
    179     deps = [
    180         ":losses_impl",
    181         ":namedtuples",
    182         "//tensorflow/python:util",
    183     ],
    184 )
    185 
    186 py_test(
            # Test target built from python/losses/python/tuple_losses_test.py;
            # :tuple_losses is its only dep from this package.
    187     name = "tuple_losses_test",
    188     srcs = ["python/losses/python/tuple_losses_test.py"],
    189     srcs_version = "PY2AND3",
    190     deps = [
    191         ":tuple_losses",
    192         "//tensorflow/python:client_testlib",
    193         "//tensorflow/python:constant_op",
    194         "//tensorflow/python:dtypes",
    195         "//tensorflow/python:variables",
    196         "//third_party/py/numpy",
    197     ],
    198 )
    199 
    200 py_library(
            # Packages conditioning_utils.py together with its
            # conditioning_utils_impl.py companion; no deps inside this package.
    201     name = "conditioning_utils",
    202     srcs = [
    203         "python/features/python/conditioning_utils.py",
    204         "python/features/python/conditioning_utils_impl.py",
    205     ],
    206     srcs_version = "PY2AND3",
    207     deps = [
    208         "//tensorflow/contrib/layers:layers_py",
    209         "//tensorflow/python:array_ops",
    210         "//tensorflow/python:embedding_ops",
    211         "//tensorflow/python:math_ops",
    212         "//tensorflow/python:tensor_util",
    213         "//tensorflow/python:util",
    214         "//tensorflow/python:variable_scope",
    215     ],
    216 )
    217 
    218 py_test(
            # Test target built from
            # python/features/python/conditioning_utils_test.py; depends on
            # :conditioning_utils.
    219     name = "conditioning_utils_test",
    220     srcs = ["python/features/python/conditioning_utils_test.py"],
    221     srcs_version = "PY2AND3",
    222     deps = [
    223         ":conditioning_utils",
    224         "//tensorflow/python:array_ops",
    225         "//tensorflow/python:client_testlib",
    226         "//tensorflow/python:dtypes",
    227     ],
    228 )
    229 
    230 py_library(
            # Packages random_tensor_pool.py together with its
            # random_tensor_pool_impl.py companion; no deps inside this package.
    231     name = "random_tensor_pool",
    232     srcs = [
    233         "python/features/python/random_tensor_pool.py",
    234         "python/features/python/random_tensor_pool_impl.py",
    235     ],
    236     srcs_version = "PY2AND3",
    237     deps = [
    238         "//tensorflow/python:array_ops",
    239         "//tensorflow/python:control_flow_ops",
    240         "//tensorflow/python:data_flow_ops",
    241         "//tensorflow/python:dtypes",
    242         "//tensorflow/python:framework_ops",
    243         "//tensorflow/python:random_ops",
    244         "//tensorflow/python:util",
    245     ],
    246 )
    247 
    248 py_test(
            # Test target built from
            # python/features/python/random_tensor_pool_test.py; depends on
            # :random_tensor_pool.
    249     name = "random_tensor_pool_test",
    250     srcs = ["python/features/python/random_tensor_pool_test.py"],
    251     srcs_version = "PY2AND3",
    252     deps = [
    253         ":random_tensor_pool",
    254         "//tensorflow/python:array_ops",
    255         "//tensorflow/python:client_testlib",
    256         "//tensorflow/python:dtypes",
    257         "//third_party/py/numpy",
    258     ],
    259 )
    260 
    261 py_library(
            # Packages virtual_batchnorm.py together with its
            # virtual_batchnorm_impl.py companion; no deps inside this package.
    262     name = "virtual_batchnorm",
    263     srcs = [
    264         "python/features/python/virtual_batchnorm.py",
    265         "python/features/python/virtual_batchnorm_impl.py",
    266     ],
    267     srcs_version = "PY2AND3",
    268     deps = [
    269         "//tensorflow/python:array_ops",
    270         "//tensorflow/python:dtypes",
    271         "//tensorflow/python:framework_ops",
    272         "//tensorflow/python:init_ops",
    273         "//tensorflow/python:math_ops",
    274         "//tensorflow/python:nn",
    275         "//tensorflow/python:tensor_shape",
    276         "//tensorflow/python:tensor_util",
    277         "//tensorflow/python:util",
    278         "//tensorflow/python:variable_scope",
    279     ],
    280 )
    281 
    282 py_test(
            # Test target built from
            # python/features/python/virtual_batchnorm_test.py; depends on
            # :virtual_batchnorm.
    283     name = "virtual_batchnorm_test",
    284     srcs = ["python/features/python/virtual_batchnorm_test.py"],
    285     srcs_version = "PY2AND3",
    286     deps = [
    287         ":virtual_batchnorm",
    288         "//tensorflow/contrib/framework:framework_py",
    289         "//tensorflow/python:array_ops",
    290         "//tensorflow/python:client_testlib",
    291         "//tensorflow/python:constant_op",
    292         "//tensorflow/python:dtypes",
    293         "//tensorflow/python:layers",
    294         "//tensorflow/python:math_ops",
    295         "//tensorflow/python:nn",
    296         "//tensorflow/python:random_ops",
    297         "//tensorflow/python:random_seed",
    298         "//tensorflow/python:variable_scope",
    299         "//tensorflow/python:variables",
    300         "//third_party/py/numpy",
    301     ],
    302 )
    303 
    304 py_library(
            # Packages clip_weights.py together with its clip_weights_impl.py
            # companion; its only non-util dep is //tensorflow/contrib/opt:opt_py.
    305     name = "clip_weights",
    306     srcs = [
    307         "python/features/python/clip_weights.py",
    308         "python/features/python/clip_weights_impl.py",
    309     ],
    310     srcs_version = "PY2AND3",
    311     deps = [
    312         "//tensorflow/contrib/opt:opt_py",
    313         "//tensorflow/python:util",
    314     ],
    315 )
    316 
    317 py_test(
            # Test target built from python/features/python/clip_weights_test.py;
            # depends on :clip_weights.
    318     name = "clip_weights_test",
    319     srcs = ["python/features/python/clip_weights_test.py"],
    320     srcs_version = "PY2AND3",
    321     deps = [
    322         ":clip_weights",
    323         "//tensorflow/python:client_testlib",
    324         "//tensorflow/python:training",
    325         "//tensorflow/python:variables",
    326     ],
    327 )
    328 
    329 py_library(
            # Packages classifier_metrics.py together with its
            # classifier_metrics_impl.py companion; no deps inside this package.
    330     name = "classifier_metrics",
    331     srcs = [
    332         "python/eval/python/classifier_metrics.py",
    333         "python/eval/python/classifier_metrics_impl.py",
    334     ],
    335     srcs_version = "PY2AND3",
    336     deps = [
    337         "//tensorflow/contrib/layers:layers_py",
    338         "//tensorflow/core:protos_all_py",
    339         "//tensorflow/python:array_ops",
    340         "//tensorflow/python:dtypes",
    341         "//tensorflow/python:framework",
    342         "//tensorflow/python:framework_ops",
    343         "//tensorflow/python:functional_ops",
    344         "//tensorflow/python:image_ops",
    345         "//tensorflow/python:linalg_ops",
    346         "//tensorflow/python:math_ops",
    347         "//tensorflow/python:nn_ops",
    348         "//tensorflow/python:platform",
    349         "//tensorflow/python:util",
    350     ],
    351 )
    352 
    353 py_test(
            # Test target built from python/eval/python/classifier_metrics_test.py;
            # depends on :classifier_metrics.
    354     name = "classifier_metrics_test",
    355     srcs = ["python/eval/python/classifier_metrics_test.py"],
    356     srcs_version = "PY2AND3",
    357     deps = [
    358         ":classifier_metrics",
    359         "//tensorflow/core:protos_all_py",
    360         "//tensorflow/python:array_ops",
    361         "//tensorflow/python:client_testlib",
    362         "//tensorflow/python:dtypes",
    363         "//tensorflow/python:framework_ops",
    364         "//tensorflow/python:variables",
    365         "//third_party/py/numpy",
    366     ],
    367 )
    368 
    369 py_library(
            # Packages eval_utils.py together with its eval_utils_impl.py
            # companion; no deps inside this package.
    370     name = "eval_utils",
    371     srcs = [
    372         "python/eval/python/eval_utils.py",
    373         "python/eval/python/eval_utils_impl.py",
    374     ],
    375     srcs_version = "PY2AND3",
    376     deps = [
    377         "//tensorflow/python:array_ops",
    378         "//tensorflow/python:framework_ops",
    379         "//tensorflow/python:util",
    380     ],
    381 )
    382 
    383 py_test(
            # Test target built from python/eval/python/eval_utils_test.py;
            # depends on :eval_utils.
    384     name = "eval_utils_test",
    385     srcs = ["python/eval/python/eval_utils_test.py"],
    386     srcs_version = "PY2AND3",
    387     deps = [
    388         ":eval_utils",
    389         "//tensorflow/python:array_ops",
    390         "//tensorflow/python:client_testlib",
    391     ],
    392 )
    393 
    394 py_library(
            # Packages summaries.py together with its summaries_impl.py companion;
            # within this package it depends on :eval_utils and :namedtuples.
    395     name = "summaries",
    396     srcs = [
    397         "python/eval/python/summaries.py",
    398         "python/eval/python/summaries_impl.py",
    399     ],
    400     srcs_version = "PY2AND3",
    401     deps = [
    402         ":eval_utils",
    403         ":namedtuples",
    404         "//tensorflow/python:array_ops",
    405         "//tensorflow/python:framework_ops",
    406         "//tensorflow/python:math_ops",
    407         "//tensorflow/python:summary",
    408         "//tensorflow/python:util",
    409         "//tensorflow/python/ops/losses",
    410     ],
    411 )
    412 
    413 py_test(
            # Test target built from python/eval/python/summaries_test.py; depends
            # on :summaries and :namedtuples.
    414     name = "summaries_test",
    415     srcs = ["python/eval/python/summaries_test.py"],
    416     srcs_version = "PY2AND3",
    417     deps = [
    418         ":namedtuples",
    419         ":summaries",
    420         "//tensorflow/python:array_ops",
    421         "//tensorflow/python:client_testlib",
    422         "//tensorflow/python:framework_ops",
    423         "//tensorflow/python:summary",
    424         "//tensorflow/python:variable_scope",
    425         "//tensorflow/python:variables",
    426     ],
    427 )
    428 
    429 py_library(
            # Packages head.py together with its head_impl.py companion; within
            # this package it depends on :namedtuples and :train, and it also
            # depends on //tensorflow/python/estimator:head.
    430     name = "head",
    431     srcs = [
    432         "python/estimator/python/head.py",
    433         "python/estimator/python/head_impl.py",
    434     ],
    435     srcs_version = "PY2AND3",
    436     deps = [
    437         ":namedtuples",
    438         ":train",
    439         "//tensorflow/python:framework_ops",
    440         "//tensorflow/python:util",
    441         "//tensorflow/python/estimator:head",
    442         "//tensorflow/python/estimator:model_fn",
    443     ],
    444 )
    445 
    446 py_test(
            # Test target built from python/estimator/python/head_test.py; depends
            # on :head and :namedtuples.
    447     name = "head_test",
    448     srcs = ["python/estimator/python/head_test.py"],
            # Explicitly requests a single shard for this test.
    449     shard_count = 1,
    450     srcs_version = "PY2AND3",
    451     deps = [
    452         ":head",
    453         ":namedtuples",
    454         "//tensorflow/python:array_ops",
    455         "//tensorflow/python:client_testlib",
    456         "//tensorflow/python:math_ops",
    457         "//tensorflow/python:training",
    458         "//tensorflow/python:variable_scope",
    459         "//tensorflow/python/estimator:model_fn",
    460     ],
    461 )
    462 
    463 py_library(
            # Packages gan_estimator.py together with its gan_estimator_impl.py
            # companion; within this package it depends on :head, :namedtuples,
            # :summaries and :train.
    464     name = "gan_estimator",
    465     srcs = [
    466         "python/estimator/python/gan_estimator.py",
    467         "python/estimator/python/gan_estimator_impl.py",
    468     ],
    469     srcs_version = "PY2AND3",
    470     deps = [
    471         ":head",
    472         ":namedtuples",
    473         ":summaries",
    474         ":train",
    475         "//tensorflow/contrib/framework:framework_py",
    476         "//tensorflow/python:framework_ops",
    477         "//tensorflow/python:util",
    478         "//tensorflow/python:variable_scope",
    479         "//tensorflow/python/estimator",
    480         "//tensorflow/python/estimator:model_fn",
    481     ],
    482 )
    483 
    484 py_test(
            # Test target built from python/estimator/python/gan_estimator_test.py;
            # depends on :gan_estimator, :namedtuples and :tuple_losses, plus the
            # external @six_archive//:six target.
    485     name = "gan_estimator_test",
    486     srcs = ["python/estimator/python/gan_estimator_test.py"],
            # Explicitly requests a single shard for this test.
    487     shard_count = 1,
    488     srcs_version = "PY2AND3",
            # NOTE(review): "notsan" presumably keeps this test out of
            # ThreadSanitizer runs -- confirm against the CI tag conventions.
    489     tags = ["notsan"],
    490     deps = [
    491         ":gan_estimator",
    492         ":namedtuples",
    493         ":tuple_losses",
    494         "//tensorflow/contrib/layers:layers_py",
    495         "//tensorflow/contrib/learn",
    496         "//tensorflow/core:protos_all_py",
    497         "//tensorflow/python:array_ops",
    498         "//tensorflow/python:client_testlib",
    499         "//tensorflow/python:control_flow_ops",
    500         "//tensorflow/python:dtypes",
    501         "//tensorflow/python:framework_ops",
    502         "//tensorflow/python:parsing_ops",
    503         "//tensorflow/python:summary",
    504         "//tensorflow/python:training",
    505         "//tensorflow/python/estimator:head",
    506         "//tensorflow/python/estimator:model_fn",
    507         "//tensorflow/python/estimator:numpy_io",
    508         "//third_party/py/numpy",
    509         "@six_archive//:six",
    510     ],
    511 )
    512 
    513 py_library(
            # Packages sliced_wasserstein.py together with its
            # sliced_wasserstein_impl.py companion; no deps inside this package.
    514     name = "sliced_wasserstein",
    515     srcs = [
    516         "python/eval/python/sliced_wasserstein.py",
    517         "python/eval/python/sliced_wasserstein_impl.py",
    518     ],
    519     srcs_version = "PY2AND3",
    520     deps = [
    521         "//tensorflow/python:array_ops",
    522         "//tensorflow/python:constant_op",
    523         "//tensorflow/python:linalg_ops",
    524         "//tensorflow/python:math_ops",
    525         "//tensorflow/python:nn",
    526         "//tensorflow/python:nn_ops",
    527         "//tensorflow/python:random_ops",
    528         "//tensorflow/python:script_ops",
    529         "//tensorflow/python:util",
    530         "//third_party/py/numpy",
    531     ],
    532 )
    533 
    534 py_test(
            # Test target built from python/eval/python/sliced_wasserstein_test.py;
            # depends on :sliced_wasserstein.
    535     name = "sliced_wasserstein_test",
    536     srcs = ["python/eval/python/sliced_wasserstein_test.py"],
    537     srcs_version = "PY2AND3",
    538     deps = [
    539         ":sliced_wasserstein",
    540         "//tensorflow/python:array_ops",
    541         "//tensorflow/python:client_testlib",
    542         "//tensorflow/python:dtypes",
    543         "//tensorflow/python:random_ops",
    544         "//third_party/py/numpy",
    545     ],
    546 )
    547 
    548 filegroup(
            # Groups the files matched by the recursive "**/*" glob, leaving out
            # any file named METADATA or OWNERS.
    549     name = "all_files",
    550     srcs = glob(
    551         ["**/*"],
    552         exclude = [
    553             "**/METADATA",
    554             "**/OWNERS",
    555         ],
    556     ),
            # Same label list as the package-level default_visibility.
    557     visibility = ["//tensorflow:__subpackages__"],
    558 )
    559