# Description:
#   A Cudnn RNN wrapper.
#   APIs are meant to change over time.

# All load() statements must precede any other statement in a BUILD file.
load(
    "//tensorflow:tensorflow.bzl",
    "cuda_py_test",
    "tf_cc_test",
    "tf_custom_op_library",
    "tf_custom_op_py_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
    "tf_kernel_library",
)

package(
    default_visibility = ["//visibility:private"],
)

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Shared object built from both the kernel and the op-definition sources.
# It is passed to cudnn_rnn_py through its `dso` attribute.
tf_custom_op_library(
    name = "python/ops/_cudnn_rnn_ops.so",
    srcs = [
        "kernels/cudnn_rnn_ops.cc",
        "ops/cudnn_rnn_ops.cc",
    ],
    deps = [
        "//tensorflow/core/kernels:bounds_check_lib",
        "@farmhash_archive//:farmhash",
    ],
)

# Kernel-only library (no op definitions). It is publicly visible and is
# also listed in cudnn_rnn_py's `kernels` attribute.
tf_kernel_library(
    name = "cudnn_rnn_kernels",
    srcs = ["kernels/cudnn_rnn_ops.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:stream_executor",
        "//tensorflow/core/kernels:bounds_check_lib",
        "//third_party/eigen3",
        "@farmhash_archive//:farmhash",
    ],
)

# Op-definition library for "cudnn_rnn_ops". The macro's generated target
# ":cudnn_rnn_ops_op_lib" is consumed by the wrapper, the Python library and
# the C++ shape test below.
tf_gen_op_libs(
    op_lib_names = ["cudnn_rnn_ops"],
    deps = [
        "//tensorflow/core:lib",
    ],
)

# Generated Python wrapper for the ops in :cudnn_rnn_ops_op_lib.
tf_gen_op_wrapper_py(
    name = "cudnn_rnn_ops",
    deps = [":cudnn_rnn_ops_op_lib"],
)

# Public Python package: the layers API plus the ops module, backed by the
# custom-op .so (dso) and the kernel/op libraries (kernels).
tf_custom_op_py_library(
    name = "cudnn_rnn_py",
    srcs = [
        "__init__.py",
        "python/layers/__init__.py",
        "python/layers/cudnn_rnn.py",
        "python/ops/cudnn_rnn_ops.py",
    ],
    dso = [
        ":python/ops/_cudnn_rnn_ops.so",
    ],
    kernels = [
        ":cudnn_rnn_kernels",
        ":cudnn_rnn_ops_op_lib",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":cudnn_rnn_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:layers_base",
        "//tensorflow/python:platform",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
    ],
)

# GPU test for the ops module. The "manual" tag excludes it from wildcard
# target patterns (e.g. //...), so it must be requested explicitly.
# NOTE(review): "requires_cudnn5" is presumably consumed by CI filtering;
# confirm where it is interpreted.
cuda_py_test(
    name = "cudnn_rnn_ops_test",
    size = "large",
    srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"],
    additional_deps = [
        ":cudnn_rnn_py",
        "//tensorflow/contrib/rnn:rnn_py",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variables",
        "//tensorflow/python/ops/losses",
    ],
    shard_count = 6,
    tags = [
        "manual",
        "requires_cudnn5",
    ],
)

# GPU test for the layers API. It uses the same dependency set, sharding
# and tags as cudnn_rnn_ops_test, but with size "enormous".
cuda_py_test(
    name = "cudnn_rnn_test",
    size = "enormous",
    srcs = ["python/kernel_tests/cudnn_rnn_test.py"],
    additional_deps = [
        ":cudnn_rnn_py",
        "//tensorflow/contrib/rnn:rnn_py",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variables",
        "//tensorflow/python/ops/losses",
    ],
    shard_count = 6,
    tags = [
        "manual",
        "requires_cudnn5",
    ],
)

# Benchmark script run as a GPU test. It is disabled under the address,
# memory and thread sanitizers.
cuda_py_test(
    name = "cudnn_rnn_ops_benchmark",
    size = "small",
    srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"],
    additional_deps = [
        ":cudnn_rnn_py",
        "//tensorflow/contrib/rnn:rnn_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:gradients",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:variables",
    ],
    tags = [
        "noasan",  # http://b/62067814
        "nomsan",
        "notsan",
        "requires_cudnn5",
    ],
)

# C++ test of the op definitions in ops/ (links :cudnn_rnn_ops_op_lib).
tf_cc_test(
    name = "cudnn_rnn_ops_test_cc",
    size = "small",
    srcs = [
        "ops/cudnn_rnn_ops_test.cc",
    ],
    deps = [
        ":cudnn_rnn_ops_op_lib",
        "//tensorflow/core",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

# Every file in this package except METADATA/OWNERS, exported to the rest of
# //tensorflow.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)