Home | History | Annotate | Download | only in cudnn_rnn
      1 # Description:
      2 #   A Cudnn RNN wrapper.
      3 #   APIs are meant to change over time.
      4 package(
      5     default_visibility = ["//visibility:private"],
      6 )
      7 
      8 licenses(["notice"])  # Apache 2.0
      9 
     10 exports_files(["LICENSE"])
     11 
     12 load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
     13 load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
     14 load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
     15 load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
     16 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
     17 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
     18 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
     19 
     20 tf_custom_op_library(
     21     name = "python/ops/_cudnn_rnn_ops.so",
     22     srcs = [
     23         "kernels/cudnn_rnn_ops.cc",
     24         "ops/cudnn_rnn_ops.cc",
     25     ],
     26     deps = [
     27         "//tensorflow/core/kernels:bounds_check_lib",
     28         "@farmhash_archive//:farmhash",
     29     ],
     30 )
     31 
     32 tf_kernel_library(
     33     name = "cudnn_rnn_kernels",
     34     srcs = ["kernels/cudnn_rnn_ops.cc"],
     35     visibility = ["//visibility:public"],
     36     deps = [
     37         "//tensorflow/core:framework",
     38         "//tensorflow/core:lib",
     39         "//tensorflow/core:lib_internal",
     40         "//tensorflow/core:stream_executor",
     41         "//tensorflow/core/kernels:bounds_check_lib",
     42         "//third_party/eigen3",
     43         "@farmhash_archive//:farmhash",
     44     ],
     45 )
     46 
     47 tf_gen_op_libs(
     48     op_lib_names = ["cudnn_rnn_ops"],
     49     deps = [
     50         "//tensorflow/core:lib",
     51     ],
     52 )
     53 
     54 tf_gen_op_wrapper_py(
     55     name = "cudnn_rnn_ops",
     56     deps = [":cudnn_rnn_ops_op_lib"],
     57 )
     58 
     59 tf_custom_op_py_library(
     60     name = "cudnn_rnn_py",
     61     srcs = [
     62         "__init__.py",
     63         "python/layers/__init__.py",
     64         "python/layers/cudnn_rnn.py",
     65         "python/ops/cudnn_rnn_ops.py",
     66     ],
     67     dso = [
     68         ":python/ops/_cudnn_rnn_ops.so",
     69     ],
     70     kernels = [
     71         ":cudnn_rnn_kernels",
     72         ":cudnn_rnn_ops_op_lib",
     73     ],
     74     srcs_version = "PY2AND3",
     75     visibility = ["//visibility:public"],
     76     deps = [
     77         ":cudnn_rnn_ops",
     78         "//tensorflow/contrib/util:util_py",
     79         "//tensorflow/python:array_ops",
     80         "//tensorflow/python:control_flow_ops",
     81         "//tensorflow/python:framework",
     82         "//tensorflow/python:framework_for_generated_wrappers",
     83         "//tensorflow/python:init_ops",
     84         "//tensorflow/python:layers_base",
     85         "//tensorflow/python:platform",
     86         "//tensorflow/python:state_ops",
     87         "//tensorflow/python:training",
     88         "//tensorflow/python:util",
     89         "//tensorflow/python:variable_scope",
     90     ],
     91 )
     92 
     93 cuda_py_test(
     94     name = "cudnn_rnn_ops_test",
     95     size = "large",
     96     srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"],
     97     additional_deps = [
     98         ":cudnn_rnn_py",
     99         "//tensorflow/core:protos_all_py",
    100         "//tensorflow/contrib/rnn:rnn_py",
    101         "//tensorflow/python/ops/losses:losses",
    102         "//tensorflow/python:array_ops",
    103         "//tensorflow/python:client_testlib",
    104         "//tensorflow/python:framework",
    105         "//tensorflow/python:framework_for_generated_wrappers",
    106         "//tensorflow/python:framework_test_lib",
    107         "//tensorflow/python:math_ops",
    108         "//tensorflow/python:platform_test",
    109         "//tensorflow/python:random_ops",
    110         "//tensorflow/python:state_ops",
    111         "//tensorflow/python:training",
    112         "//tensorflow/python:variables",
    113     ],
    114     shard_count = 6,
    115     tags = [
    116         "manual",
    117         "requires_cudnn5",
    118     ],
    119 )
    120 
    121 cuda_py_test(
    122     name = "cudnn_rnn_test",
    123     size = "enormous",
    124     srcs = ["python/kernel_tests/cudnn_rnn_test.py"],
    125     additional_deps = [
    126         ":cudnn_rnn_py",
    127         "//tensorflow/core:protos_all_py",
    128         "//tensorflow/contrib/rnn:rnn_py",
    129         "//tensorflow/python/ops/losses:losses",
    130         "//tensorflow/python:array_ops",
    131         "//tensorflow/python:client_testlib",
    132         "//tensorflow/python:framework",
    133         "//tensorflow/python:framework_for_generated_wrappers",
    134         "//tensorflow/python:framework_test_lib",
    135         "//tensorflow/python:math_ops",
    136         "//tensorflow/python:platform_test",
    137         "//tensorflow/python:random_ops",
    138         "//tensorflow/python:state_ops",
    139         "//tensorflow/python:training",
    140         "//tensorflow/python:variables",
    141     ],
    142     shard_count = 6,
    143     tags = [
    144         "manual",
    145         "requires_cudnn5",
    146     ],
    147 )
    148 
    149 cuda_py_test(
    150     name = "cudnn_rnn_ops_benchmark",
    151     size = "small",
    152     srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"],
    153     additional_deps = [
    154         ":cudnn_rnn_py",
    155         "//tensorflow/contrib/rnn:rnn_py",
    156         "//tensorflow/python:array_ops",
    157         "//tensorflow/python:client",
    158         "//tensorflow/python:client_testlib",
    159         "//tensorflow/python:control_flow_ops",
    160         "//tensorflow/python:framework_for_generated_wrappers",
    161         "//tensorflow/python:framework_test_lib",
    162         "//tensorflow/python:gradients",
    163         "//tensorflow/python:init_ops",
    164         "//tensorflow/python:platform",
    165         "//tensorflow/python:platform_test",
    166         "//tensorflow/python:variables",
    167     ],
    168     tags = [
    169         "noasan",  # http://b/62067814
    170         "nomsan",
    171         "notsan",
    172         "requires_cudnn5",
    173     ],
    174 )
    175 
    176 tf_cc_test(
    177     name = "cudnn_rnn_ops_test_cc",
    178     size = "small",
    179     srcs = [
    180         "ops/cudnn_rnn_ops_test.cc",
    181     ],
    182     deps = [
    183         ":cudnn_rnn_ops_op_lib",
    184         "//tensorflow/core",
    185         "//tensorflow/core:framework",
    186         "//tensorflow/core:lib",
    187         "//tensorflow/core:test",
    188         "//tensorflow/core:test_main",
    189         "//tensorflow/core:testlib",
    190     ],
    191 )
    192 
    193 filegroup(
    194     name = "all_files",
    195     srcs = glob(
    196         ["**/*"],
    197         exclude = [
    198             "**/METADATA",
    199             "**/OWNERS",
    200         ],
    201     ),
    202     visibility = ["//tensorflow:__subpackages__"],
    203 )
    204