Home | History | Annotate | Download | only in nccl
# Description:
#   Wraps NVIDIA NCCL (https://github.com/NVIDIA/nccl) in TensorFlow ops.
#   These APIs are expected to change over time.
      4 
      5 package(default_visibility = ["//tensorflow:__subpackages__"])
      6 
      7 licenses(["notice"])  # Apache 2.0
      8 
      9 exports_files(["LICENSE"])
     10 
     11 load(
     12     "//tensorflow:tensorflow.bzl",
     13     "tf_custom_op_library",
     14     "tf_gen_op_libs",
     15     "tf_gen_op_wrapper_py",
     16 )
     17 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
     18 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
     19 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
     20 load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
     21 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
     22 
     23 tf_custom_op_library(
     24     name = "python/ops/_nccl_ops.so",
     25     srcs = [
     26         "ops/nccl_ops.cc",
     27     ],
     28     gpu_srcs = [
     29         "kernels/nccl_manager.cc",
     30         "kernels/nccl_manager.h",
     31         "kernels/nccl_ops.cc",
     32     ],
     33     deps = if_cuda([
     34         "@nccl_archive//:nccl",
     35         "//tensorflow/core:gpu_headers_lib",
     36     ]),
     37 )
     38 
     39 tf_cuda_cc_test(
     40     name = "nccl_manager_test",
     41     size = "medium",
     42     srcs = if_cuda(
     43         [
     44             "kernels/nccl_manager.cc",
     45             "kernels/nccl_manager.h",
     46             "kernels/nccl_manager_test.cc",
     47         ],
     48         [],
     49     ),
     50     # Disabled on jenkins until errors finding nvmlShutdown are found.
     51     tags = [
     52         "manual",
     53         "multi_gpu",
     54         "no_oss",
     55         "notap",
     56     ],
     57     deps =
     58         [
     59             "//tensorflow/core:cuda",
     60             "//tensorflow/core:test",
     61             "//tensorflow/core:test_main",
     62             "//tensorflow/core:testlib",
     63             "@nccl_archive//:nccl",
     64         ],
     65 )
     66 
     67 tf_kernel_library(
     68     name = "nccl_kernels",
     69     srcs = [
     70         "kernels/nccl_manager.cc",
     71         "kernels/nccl_manager.h",
     72         "kernels/nccl_ops.cc",
     73         "kernels/nccl_rewrite.cc",
     74     ],
     75     deps = [
     76         "//tensorflow/core:core_cpu",
     77         "//tensorflow/core:framework",
     78         "//tensorflow/core:gpu_headers_lib",
     79         "//tensorflow/core:lib",
     80         "//tensorflow/core:proto_text",
     81         "//tensorflow/core:stream_executor",
     82         "@nccl_archive//:nccl",
     83     ],
     84     alwayslink = 1,
     85 )
     86 
     87 tf_gen_op_libs(
     88     op_lib_names = ["nccl_ops"],
     89     deps = [
     90         "//tensorflow/core:lib",
     91     ],
     92 )
     93 
     94 tf_gen_op_wrapper_py(
     95     name = "nccl_ops",
     96     deps = [":nccl_ops_op_lib"],
     97 )
     98 
     99 tf_custom_op_py_library(
    100     name = "nccl_py",
    101     srcs = [
    102         "__init__.py",
    103         "python/ops/nccl_ops.py",
    104     ],
    105     dso = [":python/ops/_nccl_ops.so"],
    106     kernels = if_cuda([":nccl_kernels"]) + [
    107         ":nccl_ops_op_lib",
    108     ],
    109     srcs_version = "PY2AND3",
    110     visibility = ["//visibility:public"],
    111     deps = [
    112         ":nccl_ops",
    113         "//tensorflow/contrib/util:util_py",
    114         "//tensorflow/python:device",
    115         "//tensorflow/python:framework_ops",
    116         "//tensorflow/python:platform",
    117         "//tensorflow/python:util",
    118         "//tensorflow/python/eager:context",
    119     ],
    120 )
    121 
    122 cuda_py_test(
    123     name = "nccl_ops_test",
    124     size = "small",
    125     srcs = ["python/ops/nccl_ops_test.py"],
    126     additional_deps = [
    127         ":nccl_py",
    128         "//tensorflow/python:array_ops",
    129         "//tensorflow/python:client_testlib",
    130         "//tensorflow/python:framework_for_generated_wrappers",
    131         "//tensorflow/python:framework_test_lib",
    132         "//tensorflow/python:platform_test",
    133     ],
    134     # Disabled on jenkins until errors finding nvmlShutdown are found.
    135     tags = [
    136         "manual",
    137         "multi_gpu",
    138         "no_oss",
    139         "notap",
    140     ],
    141 )
    142 
    143 filegroup(
    144     name = "all_files",
    145     srcs = glob(
    146         ["**/*"],
    147         exclude = [
    148             "**/METADATA",
    149             "**/OWNERS",
    150         ],
    151     ),
    152     visibility = ["//tensorflow:__subpackages__"],
    153 )
    154