Home | History | Annotate | Download | only in java
      1 # Description:
      2 # TensorFlow Java API.
      3 
      4 package(default_visibility = ["//visibility:private"])
      5 
      6 licenses(["notice"])  # Apache 2.0
      7 
      8 load(":build_defs.bzl", "JAVACOPTS")
      9 load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
     10 load(
     11     "//tensorflow:tensorflow.bzl",
     12     "tf_binary_additional_srcs",
     13     "tf_cc_binary",
     14     "tf_copts",
     15     "tf_custom_op_library",
     16     "tf_java_test",
     17     "tf_cc_test",
     18 )
     19 
     20 java_library(
     21     name = "tensorflow",
     22     srcs = [
     23         ":java_op_sources",
     24         ":java_sources",
     25     ],
     26     data = [":libtensorflow_jni"],
     27     javacopts = JAVACOPTS,
     28     plugins = [":processor"],
     29     visibility = ["//visibility:public"],
     30 )
     31 
     32 # NOTE(ashankar): Rule to include the Java API in the Android Inference Library
     33 # .aar. At some point, might make sense for a .aar rule here instead.
     34 filegroup(
     35     name = "java_sources",
     36     srcs = glob([
     37         "src/main/java/org/tensorflow/*.java",
     38         "src/main/java/org/tensorflow/types/*.java",
     39     ]),
     40     visibility = [
     41         "//tensorflow/contrib/android:__pkg__",
     42         "//tensorflow/java:__pkg__",
     43     ],
     44 )
     45 
     46 java_plugin(
     47     name = "processor",
     48     generates_api = True,
     49     processor_class = "org.tensorflow.processor.OperatorProcessor",
     50     visibility = ["//visibility:public"],
     51     deps = [":processor_library"],
     52 )
     53 
     54 java_library(
     55     name = "processor_library",
     56     srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
     57     javacopts = JAVACOPTS,
     58     resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]),
     59 )
     60 
     61 filegroup(
     62     name = "java_op_sources",
     63     srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [
     64         ":java_op_gen_sources",
     65     ],
     66     visibility = [
     67         "//tensorflow/java:__pkg__",
     68     ],
     69 )
     70 
     71 tf_java_op_gen_srcjar(
     72     name = "java_op_gen_sources",
     73     gen_base_package = "org.tensorflow.op",
     74     gen_tool = "java_op_gen_tool",
     75     ops_libs = [
     76         "array_ops",
     77         "candidate_sampling_ops",
     78         "control_flow_ops",
     79         "data_flow_ops",
     80         "image_ops",
     81         "io_ops",
     82         "linalg_ops",
     83         "logging_ops",
     84         "math_ops",
     85         "nn_ops",
     86         "no_op",
     87         "parsing_ops",
     88         "random_ops",
     89         "sparse_ops",
     90         "state_ops",
     91         "string_ops",
     92         "training_ops",
     93         "user_ops",
     94     ],
     95 )
     96 
     97 # Build the gen tool as a library, as it will be linked to a core/ops binary
     98 # file before making it an executable. See tf_java_op_gen_srcjar().
     99 cc_library(
    100     name = "java_op_gen_tool",
    101     srcs = [
    102         "src/gen/cc/op_gen_main.cc",
    103     ],
    104     copts = tf_copts(),
    105     deps = [
    106         ":java_op_gen_lib",
    107     ],
    108 )
    109 
    110 cc_library(
    111     name = "java_op_gen_lib",
    112     srcs = [
    113         "src/gen/cc/op_generator.cc",
    114         "src/gen/cc/source_writer.cc",
    115     ],
    116     hdrs = [
    117         "src/gen/cc/java_defs.h",
    118         "src/gen/cc/op_generator.h",
    119         "src/gen/cc/source_writer.h",
    120     ],
    121     copts = tf_copts(),
    122     deps = [
    123         "//tensorflow/core:framework",
    124         "//tensorflow/core:framework_internal",
    125         "//tensorflow/core:lib",
    126         "//tensorflow/core:lib_internal",
    127     ],
    128 )
    129 
    130 java_library(
    131     name = "testutil",
    132     testonly = 1,
    133     srcs = ["src/test/java/org/tensorflow/TestUtil.java"],
    134     javacopts = JAVACOPTS,
    135     deps = [":tensorflow"],
    136 )
    137 
    138 tf_java_test(
    139     name = "GraphTest",
    140     size = "small",
    141     srcs = ["src/test/java/org/tensorflow/GraphTest.java"],
    142     javacopts = JAVACOPTS,
    143     test_class = "org.tensorflow.GraphTest",
    144     deps = [
    145         ":tensorflow",
    146         ":testutil",
    147         "@junit",
    148     ],
    149 )
    150 
    151 tf_java_test(
    152     name = "OperationBuilderTest",
    153     size = "small",
    154     srcs = ["src/test/java/org/tensorflow/OperationBuilderTest.java"],
    155     javacopts = JAVACOPTS,
    156     test_class = "org.tensorflow.OperationBuilderTest",
    157     deps = [
    158         ":tensorflow",
    159         ":testutil",
    160         "@junit",
    161     ],
    162 )
    163 
    164 tf_java_test(
    165     name = "OperationTest",
    166     size = "small",
    167     srcs = ["src/test/java/org/tensorflow/OperationTest.java"],
    168     javacopts = JAVACOPTS,
    169     test_class = "org.tensorflow.OperationTest",
    170     deps = [
    171         ":tensorflow",
    172         ":testutil",
    173         "@junit",
    174     ],
    175 )
    176 
    177 tf_java_test(
    178     name = "SavedModelBundleTest",
    179     size = "small",
    180     srcs = ["src/test/java/org/tensorflow/SavedModelBundleTest.java"],
    181     data = ["//tensorflow/cc/saved_model:saved_model_half_plus_two"],
    182     javacopts = JAVACOPTS,
    183     test_class = "org.tensorflow.SavedModelBundleTest",
    184     deps = [
    185         ":tensorflow",
    186         ":testutil",
    187         "@junit",
    188     ],
    189 )
    190 
    191 tf_java_test(
    192     name = "SessionTest",
    193     size = "small",
    194     srcs = ["src/test/java/org/tensorflow/SessionTest.java"],
    195     javacopts = JAVACOPTS,
    196     test_class = "org.tensorflow.SessionTest",
    197     deps = [
    198         ":tensorflow",
    199         ":testutil",
    200         "@junit",
    201     ],
    202 )
    203 
    204 tf_java_test(
    205     name = "ShapeTest",
    206     size = "small",
    207     srcs = ["src/test/java/org/tensorflow/ShapeTest.java"],
    208     javacopts = JAVACOPTS,
    209     test_class = "org.tensorflow.ShapeTest",
    210     deps = [
    211         ":tensorflow",
    212         ":testutil",
    213         "@junit",
    214     ],
    215 )
    216 
    217 tf_custom_op_library(
    218     name = "my_test_op.so",
    219     srcs = ["src/test/native/my_test_op.cc"],
    220 )
    221 
    222 tf_java_test(
    223     name = "TensorFlowTest",
    224     size = "small",
    225     srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"],
    226     data = [":my_test_op.so"],
    227     javacopts = JAVACOPTS,
    228     test_class = "org.tensorflow.TensorFlowTest",
    229     deps = [
    230         ":tensorflow",
    231         "@junit",
    232     ],
    233 )
    234 
    235 tf_java_test(
    236     name = "TensorTest",
    237     size = "small",
    238     srcs = ["src/test/java/org/tensorflow/TensorTest.java"],
    239     javacopts = JAVACOPTS,
    240     test_class = "org.tensorflow.TensorTest",
    241     deps = [
    242         ":tensorflow",
    243         ":testutil",
    244         "@junit",
    245     ],
    246 )
    247 
    248 tf_java_test(
    249     name = "ScopeTest",
    250     size = "small",
    251     srcs = ["src/test/java/org/tensorflow/op/ScopeTest.java"],
    252     javacopts = JAVACOPTS,
    253     test_class = "org.tensorflow.op.ScopeTest",
    254     deps = [
    255         ":tensorflow",
    256         ":testutil",
    257         "@junit",
    258     ],
    259 )
    260 
    261 tf_java_test(
    262     name = "PrimitiveOpTest",
    263     size = "small",
    264     srcs = ["src/test/java/org/tensorflow/op/PrimitiveOpTest.java"],
    265     javacopts = JAVACOPTS,
    266     test_class = "org.tensorflow.op.PrimitiveOpTest",
    267     deps = [
    268         ":tensorflow",
    269         ":testutil",
    270         "@junit",
    271     ],
    272 )
    273 
    274 tf_java_test(
    275     name = "OperandsTest",
    276     size = "small",
    277     srcs = ["src/test/java/org/tensorflow/op/OperandsTest.java"],
    278     javacopts = JAVACOPTS,
    279     test_class = "org.tensorflow.op.OperandsTest",
    280     deps = [
    281         ":tensorflow",
    282         ":testutil",
    283         "@junit",
    284     ],
    285 )
    286 
    287 tf_java_test(
    288     name = "ConstantTest",
    289     size = "small",
    290     srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"],
    291     javacopts = JAVACOPTS,
    292     test_class = "org.tensorflow.op.core.ConstantTest",
    293     deps = [
    294         ":tensorflow",
    295         ":testutil",
    296         "@junit",
    297     ],
    298 )
    299 
    300 filegroup(
    301     name = "processor_test_resources",
    302     srcs = glob([
    303         "src/test/resources/org/tensorflow/**/*.java",
    304         "src/main/java/org/tensorflow/op/annotation/Operator.java",
    305     ]),
    306 )
    307 
    308 tf_cc_test(
    309     name = "source_writer_test",
    310     size = "small",
    311     srcs = [
    312         "src/gen/cc/source_writer_test.cc",
    313     ],
    314     deps = [
    315         ":java_op_gen_lib",
    316         "//tensorflow/core:lib",
    317         "//tensorflow/core:test",
    318         "//tensorflow/core:test_main",
    319     ],
    320 )
    321 
    322 filegroup(
    323     name = "libtensorflow_jni",
    324     srcs = select({
    325         "//tensorflow:darwin": [":libtensorflow_jni.dylib"],
    326         "//conditions:default": [":libtensorflow_jni.so"],
    327     }),
    328     visibility = ["//visibility:public"],
    329 )
    330 
    331 LINKER_VERSION_SCRIPT = ":config/version_script.lds"
    332 
    333 LINKER_EXPORTED_SYMBOLS = ":config/exported_symbols.lds"
    334 
    335 tf_cc_binary(
    336     name = "libtensorflow_jni.so",
    337     # Set linker options to strip out anything except the JNI
    338     # symbols from the library. This reduces the size of the library
    339     # considerably (~50% as of January 2017).
    340     linkopts = select({
    341         "//tensorflow:debug": [],  # Disable all custom linker options in debug mode
    342         "//tensorflow:darwin": [
    343             "-Wl,-exported_symbols_list",  # This line must be directly followed by LINKER_EXPORTED_SYMBOLS
    344             LINKER_EXPORTED_SYMBOLS,
    345         ],
    346         "//tensorflow:windows": [],
    347         "//tensorflow:windows_msvc": [],
    348         "//conditions:default": [
    349             "-z defs",
    350             "-s",
    351             "-Wl,--version-script",  #  This line must be directly followed by LINKER_VERSION_SCRIPT
    352             LINKER_VERSION_SCRIPT,
    353         ],
    354     }),
    355     linkshared = 1,
    356     linkstatic = 1,
    357     deps = [
    358         "//tensorflow/java/src/main/native",
    359         LINKER_VERSION_SCRIPT,
    360         LINKER_EXPORTED_SYMBOLS,
    361     ],
    362 )
    363 
    364 genrule(
    365     name = "pom",
    366     outs = ["pom.xml"],
    367     cmd = "$(location generate_pom) >$@",
    368     output_to_bindir = 1,
    369     tools = [":generate_pom"] + tf_binary_additional_srcs(),
    370 )
    371 
    372 tf_cc_binary(
    373     name = "generate_pom",
    374     srcs = ["generate_pom.cc"],
    375     deps = ["//tensorflow/c:c_api"],
    376 )
    377 
    378 # System.loadLibrary() on OS X looks for ".dylib" or ".jnilib"
    379 # and no ".so". If and when https://github.com/bazelbuild/bazel/issues/914
    380 # is resolved, perhaps this workaround rule can be removed.
    381 genrule(
    382     name = "darwin-compat",
    383     srcs = [":libtensorflow_jni.so"],
    384     outs = ["libtensorflow_jni.dylib"],
    385     cmd = "cp $< $@",
    386     output_to_bindir = 1,
    387 )
    388 
    389 filegroup(
    390     name = "all_files",
    391     srcs = glob(
    392         ["**/*"],
    393         exclude = [
    394             "**/METADATA",
    395             "**/OWNERS",
    396         ],
    397     ),
    398     visibility = ["//tensorflow:__subpackages__"],
    399 )
    400