# Source-browser navigation residue (not part of the BUILD rules): Home | History | Annotate | Download | only in kernels
      1 licenses(["notice"])  # Apache 2.0
      2 
      3 package(
      4     default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
      5 )
      6 
      7 load("//tensorflow:tensorflow.bzl", "tf_copts")
      8 load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
      9 
     10 tf_kernel_library(
     11     name = "xla_ops",
     12     srcs = [
     13         "aggregate_ops.cc",
     14         "arg_op.cc",
     15         "batch_matmul_op.cc",
     16         "batch_norm_op.cc",
     17         "batchtospace_op.cc",
     18         "bcast_ops.cc",
     19         "bias_ops.cc",
     20         "binary_ops.cc",
     21         "cast_op.cc",
     22         "categorical_op.cc",
     23         "cholesky_op.cc",
     24         "concat_op.cc",
     25         "const_op.cc",
     26         "conv_ops.cc",
     27         "cross_op.cc",
     28         "cwise_ops.cc",
     29         "cwise_ops.h",
     30         "depthtospace_op.cc",
     31         "diag_op.cc",
     32         "dynamic_stitch_op.cc",
     33         "elu_op.cc",
     34         "extract_image_patches_op.cc",
     35         "fake_quantize_ops.cc",
     36         "fft_ops.cc",
     37         "fill_op.cc",
     38         "function_ops.cc",
     39         "gather_op.cc",
     40         "gather_op_helpers.h",
     41         "identity_op.cc",
     42         "image_ops.cc",
     43         "image_resize_ops.cc",
     44         "index_ops.cc",
     45         "l2loss_op.cc",
     46         "lrn_ops.cc",
     47         "matmul_op.cc",
     48         "matrix_band_part_op.cc",
     49         "matrix_set_diag_op.cc",
     50         "matrix_triangular_solve_op.cc",
     51         "mirror_pad_op.cc",
     52         "no_op.cc",
     53         "one_hot_op.cc",
     54         "pack_op.cc",
     55         "pad_op.cc",
     56         "pooling_ops.cc",
     57         "quantize_and_dequantize_op.cc",
     58         "random_ops.cc",
     59         "reduction_ops.cc",
     60         "reduction_ops.h",
     61         "reduction_ops_common.cc",
     62         "relu_op.cc",
     63         "reshape_op.cc",
     64         "retval_op.cc",
     65         "reverse_op.cc",
     66         "reverse_sequence_op.cc",
     67         "scan_ops.cc",
     68         "scatter_nd_op.cc",
     69         "segment_reduction_ops.cc",
     70         "select_op.cc",
     71         "sendrecv_ops.cc",
     72         "sequence_ops.cc",
     73         "shape_op.cc",
     74         "shape_util.cc",
     75         "slice_op.cc",
     76         "softmax_op.cc",
     77         "spacetobatch_op.cc",
     78         "spacetodepth_op.cc",
     79         "split_op.cc",
     80         "stack_ops.cc",
     81         "stateless_random_ops.cc",
     82         "strided_slice_op.cc",
     83         "tensor_array_ops.cc",
     84         "tile_ops.cc",
     85         "training_ops.cc",
     86         "transpose_op.cc",
     87         "unary_ops.cc",
     88         "unpack_op.cc",
     89         "variable_ops.cc",
     90     ],
     91     hdrs = [
     92         "index_ops.h",
     93         "shape_util.h",
     94     ],
     95     deps = [
     96         ":while_op",
     97         "//tensorflow/compiler/tf2xla:common",
     98         "//tensorflow/compiler/tf2xla:xla_compiler",
     99         "//tensorflow/compiler/tf2xla/lib:batch_dot",
    100         "//tensorflow/compiler/tf2xla/lib:cholesky",
    101         "//tensorflow/compiler/tf2xla/lib:scatter",
    102         "//tensorflow/compiler/tf2xla/lib:triangular_solve",
    103         "//tensorflow/compiler/tf2xla/lib:util",
    104         "//tensorflow/compiler/tf2xla/lib:while_loop",
    105         "//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
    106         "//tensorflow/compiler/xla:array4d",
    107         "//tensorflow/compiler/xla:literal_util",
    108         "//tensorflow/compiler/xla:shape_util",
    109         "//tensorflow/compiler/xla:status_macros",
    110         "//tensorflow/compiler/xla:util",
    111         "//tensorflow/compiler/xla:xla_data_proto",
    112         "//tensorflow/compiler/xla/client:client_library",
    113         "//tensorflow/compiler/xla/client:computation_builder",
    114         "//tensorflow/compiler/xla/client/lib:arithmetic",
    115         "//tensorflow/core:framework",
    116         "//tensorflow/core:image_ops_op_lib",
    117         "//tensorflow/core:lib",
    118         "//tensorflow/core:linalg_ops_op_lib",
    119         "//tensorflow/core:protos_all_cc",
    120         "//tensorflow/core:spectral_ops_op_lib",
    121         "//tensorflow/core:stateless_random_ops_op_lib",
    122         "//tensorflow/core/kernels:bounds_check",
    123         "//tensorflow/core/kernels:concat_lib",
    124         "//tensorflow/core/kernels:constant_op",
    125         "//tensorflow/core/kernels:control_flow_ops",
    126         "//tensorflow/core/kernels:conv_ops",
    127         "//tensorflow/core/kernels:cwise_op",
    128         "//tensorflow/core/kernels:no_op",
    129         "//tensorflow/core/kernels:ops_util",
    130         "//tensorflow/core/kernels:pooling_ops",
    131         "//tensorflow/core/kernels:random_op",
    132         "//tensorflow/core/kernels:resource_variable_ops",
    133         "//tensorflow/core/kernels:sendrecv_ops",
    134         "//tensorflow/core/kernels:sparse_to_dense_op",
    135         "//tensorflow/core/kernels:stack_ops",
    136         "//tensorflow/core/kernels:training_ops",
    137         "//tensorflow/core/kernels:transpose_op",
    138     ],
    139 )
    140 
    141 tf_kernel_library(
    142     name = "while_op",
    143     srcs = ["while_op.cc"],
    144     hdrs = ["while_op.h"],
    145     deps = [
    146         "//tensorflow/compiler/tf2xla:common",
    147         "//tensorflow/compiler/tf2xla:xla_compiler",
    148         "//tensorflow/compiler/tf2xla/ops:functional_ops",
    149         "//tensorflow/compiler/xla:literal_util",
    150         "//tensorflow/compiler/xla/client:computation_builder",
    151         "//tensorflow/core:framework",
    152         "//tensorflow/core:lib",
    153         "//tensorflow/core:protos_all_cc",
    154     ],
    155 )
    156 
    157 # Kernels that only work on CPU, because they use XLA custom calls.
    158 # Only link this when using the CPU backend for XLA.
    159 tf_kernel_library(
    160     name = "xla_cpu_only_ops",
    161     srcs = ["index_ops_cpu.cc"],
    162     deps = [
    163         ":index_ops_kernel_argmax_float_1d",
    164         ":index_ops_kernel_argmax_float_2d",
    165         "//tensorflow/compiler/tf2xla:common",
    166         "//tensorflow/compiler/tf2xla:xla_compiler",
    167         "//tensorflow/compiler/xla:literal_util",
    168         "//tensorflow/compiler/xla/client:client_library",
    169         "//tensorflow/compiler/xla/client:computation_builder",
    170         "//tensorflow/compiler/xla/client/lib:arithmetic",
    171         "//tensorflow/core:framework",
    172         "//tensorflow/core:lib",
    173         "//tensorflow/core/kernels:argmax_op",
    174         "//tensorflow/core/kernels:bounds_check",
    175     ],
    176 )
    177 
    178 cc_library(
    179     name = "index_ops_kernel_argmax_float_1d",
    180     srcs = ["index_ops_kernel_argmax_float_1d.cc"],
    181     copts = tf_copts(),
    182     visibility = ["//visibility:public"],
    183     deps = [
    184         "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
    185         "//tensorflow/core:framework_lite",
    186         "//third_party/eigen3",
    187     ],
    188     alwayslink = 1,
    189 )
    190 
    191 cc_library(
    192     name = "index_ops_kernel_argmax_float_2d",
    193     srcs = ["index_ops_kernel_argmax_float_2d.cc"],
    194     copts = tf_copts(),
    195     visibility = ["//visibility:public"],
    196     deps = [
    197         "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
    198         "//tensorflow/core:framework_lite",
    199         "//third_party/eigen3",
    200     ],
    201     alwayslink = 1,
    202 )
    203 
    204 # -----------------------------------------------------------------------------
    205 
    206 filegroup(
    207     name = "all_files",
    208     srcs = glob(
    209         ["**/*"],
    210         exclude = [
    211             "**/METADATA",
    212             "**/OWNERS",
    213         ],
    214     ),
    215     visibility = ["//tensorflow:__subpackages__"],
    216 )
    217