# Files for using TFGAN framework.

# load() statements must precede all other statements in a BUILD file
# (buildifier `load-on-top`); it also binds `py_test` before any rule uses it.
# This overrides the native py_test with the TensorFlow macro.
load("//tensorflow:tensorflow.bzl", "py_test")

package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# ---------------------------------------------------------------------------
# Package-level libraries. Each is an `__init__.py` that depends on the
# implementation libraries defined further down in this file.
# ---------------------------------------------------------------------------

# Top-level TFGAN package: its __init__.py depends on every sub-package below.
py_library(
    name = "gan",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":estimator",
        ":eval",
        ":features",
        ":losses",
        ":namedtuples",
        ":train",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "namedtuples",
    srcs = ["python/namedtuples.py"],
    srcs_version = "PY2AND3",
)

py_library(
    name = "train",
    srcs = ["python/train.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":losses",
        ":namedtuples",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/contrib/slim:learning",
        "//tensorflow/contrib/training:training_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/ops/distributions",
        "//tensorflow/python/ops/losses",
    ],
)

py_test(
    name = "train_test",
    srcs = ["python/train_test.py"],
    srcs_version = "PY2AND3",
    # NOTE(review): "notsan" presumably excludes this test from sanitizer
    # builds; the reason is not recorded here — TODO document.
    tags = ["notsan"],
    deps = [
        ":features",
        ":namedtuples",
        ":train",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/contrib/slim:learning",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/ops/distributions",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "eval",
    srcs = ["python/eval/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":classifier_metrics",
        ":eval_utils",
        ":sliced_wasserstein",
        ":summaries",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "estimator",
    srcs = ["python/estimator/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":gan_estimator",
        ":head",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "losses",
    srcs = ["python/losses/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":losses_impl",
        ":tuple_losses",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "features",
    srcs = ["python/features/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":clip_weights",
        ":conditioning_utils",
        ":random_tensor_pool",
        ":virtual_batchnorm",
        "//tensorflow/python:util",
    ],
)

# ---------------------------------------------------------------------------
# Losses (python/losses/python/).
# ---------------------------------------------------------------------------

py_library(
    name = "losses_impl",
    srcs = ["python/losses/python/losses_impl.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:clip_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/ops/distributions",
        "//tensorflow/python/ops/losses",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "losses_impl_test",
    srcs = ["python/losses/python/losses_impl_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":losses_impl",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:clip_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/ops/distributions",
        "//tensorflow/python/ops/losses",
    ],
)

py_library(
    name = "tuple_losses",
    srcs = [
        "python/losses/python/losses_wargs.py",
        "python/losses/python/tuple_losses.py",
        "python/losses/python/tuple_losses_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":losses_impl",
        ":namedtuples",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "tuple_losses_test",
    srcs = ["python/losses/python/tuple_losses_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":tuple_losses",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)

# ---------------------------------------------------------------------------
# Features (python/features/python/).
# ---------------------------------------------------------------------------

py_library(
    name = "conditioning_utils",
    srcs = [
        "python/features/python/conditioning_utils.py",
        "python/features/python/conditioning_utils_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
    ],
)

py_test(
    name = "conditioning_utils_test",
    srcs = ["python/features/python/conditioning_utils_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":conditioning_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
    ],
)

py_library(
    name = "random_tensor_pool",
    srcs = [
        "python/features/python/random_tensor_pool.py",
        "python/features/python/random_tensor_pool_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:data_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "random_tensor_pool_test",
    srcs = ["python/features/python/random_tensor_pool_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":random_tensor_pool",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "virtual_batchnorm",
    srcs = [
        "python/features/python/virtual_batchnorm.py",
        "python/features/python/virtual_batchnorm_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
    ],
)

py_test(
    name = "virtual_batchnorm_test",
    srcs = ["python/features/python/virtual_batchnorm_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":virtual_batchnorm",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:layers",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "clip_weights",
    srcs = [
        "python/features/python/clip_weights.py",
        "python/features/python/clip_weights_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/opt:opt_py",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "clip_weights_test",
    srcs = ["python/features/python/clip_weights_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":clip_weights",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:training",
        "//tensorflow/python:variables",
    ],
)

# ---------------------------------------------------------------------------
# Evaluation (python/eval/python/).
# ---------------------------------------------------------------------------

py_library(
    name = "classifier_metrics",
    srcs = [
        "python/eval/python/classifier_metrics.py",
        "python/eval/python/classifier_metrics_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python:image_ops",
        "//tensorflow/python:linalg_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "classifier_metrics_test",
    srcs = ["python/eval/python/classifier_metrics_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":classifier_metrics",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "eval_utils",
    srcs = [
        "python/eval/python/eval_utils.py",
        "python/eval/python/eval_utils_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
    ],
)

py_test(
    name = "eval_utils_test",
    srcs = ["python/eval/python/eval_utils_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":eval_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
    ],
)

py_library(
    name = "summaries",
    srcs = [
        "python/eval/python/summaries.py",
        "python/eval/python/summaries_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":eval_utils",
        ":namedtuples",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:util",
        "//tensorflow/python/ops/losses",
    ],
)

py_test(
    name = "summaries_test",
    srcs = ["python/eval/python/summaries_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":namedtuples",
        ":summaries",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
)

# ---------------------------------------------------------------------------
# Estimator (python/estimator/python/).
# ---------------------------------------------------------------------------

py_library(
    name = "head",
    srcs = [
        "python/estimator/python/head.py",
        "python/estimator/python/head_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":namedtuples",
        ":train",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/estimator:head",
        "//tensorflow/python/estimator:model_fn",
    ],
)

py_test(
    name = "head_test",
    srcs = ["python/estimator/python/head_test.py"],
    shard_count = 1,
    srcs_version = "PY2AND3",
    deps = [
        ":head",
        ":namedtuples",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/estimator:model_fn",
    ],
)

py_library(
    name = "gan_estimator",
    srcs = [
        "python/estimator/python/gan_estimator.py",
        "python/estimator/python/gan_estimator_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":head",
        ":namedtuples",
        ":summaries",
        ":train",
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/estimator",
        "//tensorflow/python/estimator:model_fn",
    ],
)

py_test(
    name = "gan_estimator_test",
    srcs = ["python/estimator/python/gan_estimator_test.py"],
    shard_count = 1,
    srcs_version = "PY2AND3",
    # NOTE(review): "notsan" presumably excludes this test from sanitizer
    # builds; the reason is not recorded here — TODO document.
    tags = ["notsan"],
    deps = [
        ":gan_estimator",
        ":namedtuples",
        ":tuple_losses",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/contrib/learn",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:training",
        "//tensorflow/python/estimator:head",
        "//tensorflow/python/estimator:model_fn",
        "//tensorflow/python/estimator:numpy_io",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

# ---------------------------------------------------------------------------
# Sliced Wasserstein distance (python/eval/python/); re-exported via :eval.
# ---------------------------------------------------------------------------

py_library(
    name = "sliced_wasserstein",
    srcs = [
        "python/eval/python/sliced_wasserstein.py",
        "python/eval/python/sliced_wasserstein_impl.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:linalg_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:script_ops",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)

py_test(
    name = "sliced_wasserstein_test",
    srcs = ["python/eval/python/sliced_wasserstein_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":sliced_wasserstein",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:random_ops",
        "//third_party/py/numpy",
    ],
)

# Every file in this package except METADATA/OWNERS files.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)