Home | History | Annotate | Download | only in optimize
      1 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
      2 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
      3 
      4 package(default_visibility = [
      5     "//visibility:public",
      6 ])
      7 
      8 licenses(["notice"])  # Apache 2.0
      9 
     10 exports_files(glob([
     11     "testdata/*.bin",
     12 ]))
     13 
     14 cc_library(
     15     name = "quantization_utils",
     16     srcs = ["quantization_utils.cc"],
     17     hdrs = ["quantization_utils.h"],
     18     deps = [
     19         "//tensorflow/lite:framework",
     20         "//tensorflow/lite/c:c_api_internal",
     21         "//tensorflow/lite/kernels/internal:round",
     22         "//tensorflow/lite/kernels/internal:tensor_utils",
     23         "//tensorflow/lite/kernels/internal:types",
     24         "//tensorflow/lite/schema:schema_fbs",
     25         "@com_google_absl//absl/memory",
     26     ],
     27 )
     28 
     29 tf_cc_test(
     30     name = "quantization_utils_test",
     31     srcs = ["quantization_utils_test.cc"],
     32     args = [
     33         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     34     ],
     35     data = [
     36         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
     37     ],
     38     tags = [
     39         "tflite_not_portable_android",
     40         "tflite_not_portable_ios",
     41     ],
     42     deps = [
     43         ":quantization_utils",
     44         ":test_util",
     45         "//tensorflow/core:framework_internal",
     46         "//tensorflow/core:lib",
     47         "//tensorflow/lite:framework",
     48         "//tensorflow/lite/schema:schema_fbs",
     49         "@com_google_googletest//:gtest",
     50         "@flatbuffers",
     51     ],
     52 )
     53 
     54 cc_library(
     55     name = "quantize_weights",
     56     srcs = ["quantize_weights.cc"],
     57     hdrs = ["quantize_weights.h"],
     58     deps = [
     59         ":quantization_utils",
     60         "@com_google_absl//absl/memory",
     61         "@flatbuffers",
     62         "//tensorflow/lite:framework",
     63         # TODO(suharshs): Move the relevant quantization utils to a non-internal location.
     64         "//tensorflow/lite/kernels/internal:tensor_utils",
     65         "//tensorflow/lite/schema:schema_fbs",
     66         "//tensorflow/core:tflite_portable_logging",
     67     ],
     68 )
     69 
     70 tf_cc_test(
     71     name = "quantize_weights_test",
     72     srcs = ["quantize_weights_test.cc"],
     73     args = [
     74         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     75     ],
     76     data = [
     77         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
     78         "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
     79     ],
     80     tags = [
     81         "tflite_not_portable_android",
     82         "tflite_not_portable_ios",
     83     ],
     84     deps = [
     85         ":quantize_weights",
     86         ":test_util",
     87         "//tensorflow/core:framework_internal",
     88         "//tensorflow/core:lib",
     89         "//tensorflow/lite:framework",
     90         "//tensorflow/lite/schema:schema_fbs",
     91         "@com_google_googletest//:gtest",
     92         "@flatbuffers",
     93     ],
     94 )
     95 
     96 cc_library(
     97     name = "subgraph_quantizer",
     98     srcs = ["subgraph_quantizer.cc"],
     99     hdrs = ["subgraph_quantizer.h"],
    100     deps = [
    101         ":quantization_utils",
    102         "//tensorflow/lite:framework",
    103         "//tensorflow/lite/core/api",
    104         "//tensorflow/lite/kernels/internal:round",
    105         "//tensorflow/lite/kernels/internal:tensor_utils",
    106         "//tensorflow/lite/schema:schema_fbs",
    107         "@com_google_absl//absl/memory",
    108         "@flatbuffers",
    109     ],
    110 )
    111 
    112 cc_library(
    113     name = "test_util",
    114     testonly = 1,
    115     srcs = ["test_util.cc"],
    116     hdrs = ["test_util.h"],
    117     deps = [
    118         "//tensorflow/lite:framework",
    119         "//tensorflow/lite/core/api",
    120         "@com_google_googletest//:gtest",
    121         "@flatbuffers",
    122     ],
    123 )
    124 
    125 tf_cc_test(
    126     name = "subgraph_quantizer_test",
    127     srcs = ["subgraph_quantizer_test.cc"],
    128     args = [
    129         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
    130     ],
    131     data = [
    132         "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
    133         "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
    134         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
    135         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
    136         "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
    137     ],
    138     tags = [
    139         "tflite_not_portable_android",
    140         "tflite_not_portable_ios",
    141     ],
    142     deps = [
    143         ":subgraph_quantizer",
    144         ":test_util",
    145         "//tensorflow/core:framework_internal",
    146         "//tensorflow/core:lib",
    147         "//tensorflow/lite:framework",
    148         "//tensorflow/lite/schema:schema_fbs",
    149         "@com_google_googletest//:gtest",
    150         "@flatbuffers",
    151     ],
    152 )
    153 
    154 cc_library(
    155     name = "quantize_model",
    156     srcs = ["quantize_model.cc"],
    157     hdrs = ["quantize_model.h"],
    158     deps = [
    159         ":subgraph_quantizer",
    160         "//tensorflow/lite:framework",
    161         "//tensorflow/lite/core/api",
    162         "//tensorflow/lite/schema:schema_fbs",
    163         "@com_google_absl//absl/memory",
    164         "@flatbuffers",
    165     ],
    166 )
    167 
    168 tf_cc_test(
    169     name = "quantize_model_test",
    170     srcs = ["quantize_model_test.cc"],
    171     args = [
    172         "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
    173     ],
    174     data = [
    175         "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
    176     ],
    177     tags = [
    178         "tflite_not_portable_android",
    179         "tflite_not_portable_ios",
    180     ],
    181     deps = [
    182         ":quantize_model",
    183         ":test_util",
    184         "//tensorflow/core:framework_internal",
    185         "//tensorflow/core:lib",
    186         "//tensorflow/lite:framework",
    187         "//tensorflow/lite/schema:schema_fbs",
    188         "@com_google_googletest//:gtest",
    189         "@flatbuffers",
    190     ],
    191 )
    192 
    193 tflite_portable_test_suite()
    194