# Utilities for building XLA computations.

      3 licenses(["notice"])  # Apache 2.0
      4 
      5 package(
      6     default_visibility = ["//tensorflow/compiler/tf2xla:friends"],
      7 )
      8 
      9 # Filegroup used to collect source files for dependency checking.
     10 filegroup(
     11     name = "c_srcs",
     12     data = glob([
     13         "**/*.cc",
     14         "**/*.h",
     15     ]),
     16 )
     17 
     18 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
     19 
     20 cc_library(
     21     name = "batch_dot",
     22     srcs = ["batch_dot.cc"],
     23     hdrs = ["batch_dot.h"],
     24     deps = [
     25         "//tensorflow/compiler/xla:shape_util",
     26         "//tensorflow/compiler/xla:status_macros",
     27         "//tensorflow/compiler/xla:statusor",
     28         "//tensorflow/compiler/xla/client:computation",
     29         "//tensorflow/compiler/xla/client:computation_builder",
     30         "//tensorflow/core:lib",
     31     ],
     32 )
     33 
     34 cc_library(
     35     name = "cholesky",
     36     srcs = ["cholesky.cc"],
     37     hdrs = ["cholesky.h"],
     38     deps = [
     39         ":batch_dot",
     40         ":triangular_solve",
     41         ":util",
     42         "//tensorflow/compiler/xla:literal_util",
     43         "//tensorflow/compiler/xla:shape_util",
     44         "//tensorflow/compiler/xla:status_macros",
     45         "//tensorflow/compiler/xla:statusor",
     46         "//tensorflow/compiler/xla/client:computation",
     47         "//tensorflow/compiler/xla/client:computation_builder",
     48         "//tensorflow/core:lib",
     49     ],
     50 )
     51 
     52 cc_library(
     53     name = "scatter",
     54     srcs = ["scatter.cc"],
     55     hdrs = ["scatter.h"],
     56     deps = [
     57         ":util",
     58         ":while_loop",
     59         "//tensorflow/compiler/xla:literal_util",
     60         "//tensorflow/compiler/xla:shape_util",
     61         "//tensorflow/compiler/xla:status_macros",
     62         "//tensorflow/compiler/xla:statusor",
     63         "//tensorflow/compiler/xla:util",
     64         "//tensorflow/compiler/xla/client:computation",
     65         "//tensorflow/compiler/xla/client:computation_builder",
     66         "//tensorflow/compiler/xla/client/lib:arithmetic",
     67         "//tensorflow/core:lib",
     68     ],
     69 )
     70 
     71 cc_library(
     72     name = "triangular_solve",
     73     srcs = ["triangular_solve.cc"],
     74     hdrs = ["triangular_solve.h"],
     75     deps = [
     76         ":batch_dot",
     77         ":util",
     78         "//tensorflow/compiler/xla:literal_util",
     79         "//tensorflow/compiler/xla:shape_util",
     80         "//tensorflow/compiler/xla:status_macros",
     81         "//tensorflow/compiler/xla:statusor",
     82         "//tensorflow/compiler/xla:types",
     83         "//tensorflow/compiler/xla:util",
     84         "//tensorflow/compiler/xla/client:computation",
     85         "//tensorflow/compiler/xla/client:computation_builder",
     86         "//tensorflow/core:lib",
     87     ],
     88 )
     89 
     90 xla_test(
     91     name = "triangular_solve_test",
     92     srcs = ["triangular_solve_test.cc"],
     93     deps = [
     94         ":triangular_solve",
     95         "//tensorflow/compiler/xla:array2d",
     96         "//tensorflow/compiler/xla:literal_util",
     97         "//tensorflow/compiler/xla:shape_util",
     98         "//tensorflow/compiler/xla:statusor",
     99         "//tensorflow/compiler/xla:test",
    100         "//tensorflow/compiler/xla:types",
    101         "//tensorflow/compiler/xla:xla_data_proto",
    102         "//tensorflow/compiler/xla/client:computation_builder",
    103         "//tensorflow/compiler/xla/client:global_data",
    104         "//tensorflow/compiler/xla/client:local_client",
    105         "//tensorflow/compiler/xla/tests:client_library_test_base",
    106         "//tensorflow/compiler/xla/tests:literal_test_util",
    107         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    108         "//tensorflow/core:lib",
    109         "//tensorflow/core:test",
    110     ],
    111 )
    112 
    113 cc_library(
    114     name = "util",
    115     srcs = ["util.cc"],
    116     hdrs = ["util.h"],
    117     deps = [
    118         "//tensorflow/compiler/xla:literal_util",
    119         "//tensorflow/compiler/xla:shape_util",
    120         "//tensorflow/compiler/xla:status_macros",
    121         "//tensorflow/compiler/xla:statusor",
    122         "//tensorflow/compiler/xla:util",
    123         "//tensorflow/compiler/xla/client:computation",
    124         "//tensorflow/compiler/xla/client:computation_builder",
    125         "//tensorflow/core:lib",
    126     ],
    127 )
    128 
    129 cc_library(
    130     name = "while_loop",
    131     srcs = ["while_loop.cc"],
    132     hdrs = ["while_loop.h"],
    133     deps = [
    134         ":util",
    135         "//tensorflow/compiler/xla:shape_util",
    136         "//tensorflow/compiler/xla:status_macros",
    137         "//tensorflow/compiler/xla:statusor",
    138         "//tensorflow/compiler/xla/client:computation",
    139         "//tensorflow/compiler/xla/client:computation_builder",
    140         "//tensorflow/core:lib",
    141     ],
    142 )
    143 
# -----------------------------------------------------------------------------

    146 filegroup(
    147     name = "all_files",
    148     srcs = glob(
    149         ["**/*"],
    150         exclude = [
    151             "**/METADATA",
    152             "**/OWNERS",
    153         ],
    154     ),
    155     visibility = ["//tensorflow:__subpackages__"],
    156 )
    157