Home | History | Annotate | Download | only in toco
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <tuple>
     16 #include <vector>
     17 
     18 #include <gtest/gtest.h>
     19 #include "tensorflow/core/lib/core/status.h"
     20 #include "tensorflow/lite/testing/util.h"
     21 #include "tensorflow/lite/toco/model.h"
     22 #include "tensorflow/lite/toco/toco_port.h"
     23 #include "tensorflow/lite/toco/tooling_util.h"
     24 
     25 namespace toco {
     26 
     27 enum class Agreement { kBroadcast, kExtend, kBroadcastNotExtend, kNeither };
     28 
     29 // A pair of Shapes and whether they should agree up to broadcasting, extending
     30 // or neither.
     31 struct ShapePair {
     32   Shape left;
     33   Shape right;
     34   Agreement agreement;
     35 };
     36 
     37 std::vector<ShapePair> CreateShapePairs() {
     38   return std::vector<ShapePair>(
     39       {// These agree up to broadcast.
     40        {Shape({3}), Shape({3}), Agreement::kBroadcast},
     41        {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast},
     42        {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast},
     43        {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast},
     44        {Shape({}), Shape({3}), Agreement::kBroadcast},
     45        {Shape({}), Shape({3, 1}), Agreement::kBroadcast},
     46 
     47        // These extend (and therefore broadcast).
     48        {Shape({3}), Shape({3}), Agreement::kExtend},
     49        {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kExtend},
     50        {Shape({1, 1, 3}), Shape({1, 1, 3}), Agreement::kExtend},
     51        {Shape({1, 1, 3}), Shape({3}), Agreement::kExtend},
     52        {Shape({1, 1, 3}), Shape({1, 3}), Agreement::kExtend},
     53 
     54        // These strictly broadcast and do not extend.
     55        {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcastNotExtend},
     56        {Shape({5, 4}), Shape({1}), Agreement::kBroadcastNotExtend},
     57        {Shape({5, 4}), Shape({4}), Agreement::kBroadcastNotExtend},
     58        {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend},
     59        {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend},
     60        {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend},
     61        {Shape({3, 1}), Shape({}), Agreement::kBroadcastNotExtend},
     62 
     63        // These do not broadcast (and therefore also do not extend).
     64        {Shape({3}), Shape({4}), Agreement::kNeither},
     65        {Shape({2, 1}), Shape({8, 4, 3}), Agreement::kNeither}});
     66 }
     67 
     68 // ShapeTest is an empty parameterized test fixture since there is no state.
     69 class ShapeTest : public ::testing::TestWithParam<ShapePair> {};
     70 
     71 TEST_P(ShapeTest, Agrees) {
     72   const ShapePair& param = GetParam();
     73 
     74   switch (param.agreement) {
     75     case Agreement::kBroadcast: {
     76       EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
     77       break;
     78     }
     79     case Agreement::kExtend: {
     80       EXPECT_TRUE(ShapesAgreeUpToExtending(param.left, param.right));
     81       // Anything that extends should also broadcast.
     82       EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
     83       break;
     84     }
     85     case Agreement::kBroadcastNotExtend: {
     86       // Verify that it strictly broadcasts but does not extend.
     87       EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right));
     88       EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right));
     89       break;
     90     }
     91     case Agreement::kNeither: {
     92       EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right));
     93       EXPECT_FALSE(ShapesAgreeUpToBroadcasting(param.left, param.right));
     94       break;
     95     }
     96   }
     97 }
     98 
     99 INSTANTIATE_TEST_SUITE_P(AgreeBroadcast, ShapeTest,
    100                          ::testing::ValuesIn(CreateShapePairs()));
    101 
    102 static const char kNegativeValuesMessage[] =
    103     "Tensor shape should not include negative values";
    104 static const char kLargeTensorMessage[] = "Tensor shape is too large";
    105 
    106 TEST(NumElementsTest, Int) {
    107   int count;
    108   tensorflow::Status status = tensorflow::Status::OK();
    109 
    110   status = NumElements(std::vector<int>{1024, 1024, 2047}, &count);
    111   EXPECT_TRUE(status.ok());
    112   EXPECT_EQ(count, 2146435072);
    113 
    114   status = NumElements(std::vector<int>{1024, 0, 2048}, &count);
    115   EXPECT_TRUE(status.ok());
    116   EXPECT_EQ(count, 0);
    117 
    118   status = NumElements(std::vector<int>{1, 2, -3}, &count);
    119   EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
    120 
    121   status = NumElements(std::vector<int>{1024, 1024, 2048}, &count);
    122   EXPECT_EQ(status.error_message(), kLargeTensorMessage);
    123 }
    124 
    125 TEST(NumElementsTest, Int32) {
    126   int32_t count;
    127   tensorflow::Status status = tensorflow::Status::OK();
    128 
    129   status = NumElements(std::vector<int32_t>{1024, 1024, 2047}, &count);
    130   EXPECT_TRUE(status.ok());
    131   EXPECT_EQ(count, 2146435072);
    132 
    133   status = NumElements(std::vector<int32_t>{1, 2, -3}, &count);
    134   EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
    135 
    136   status = NumElements(std::vector<int32_t>{1024, 1024, 2048}, &count);
    137   EXPECT_EQ(status.error_message(), kLargeTensorMessage);
    138 }
    139 
    140 TEST(NumElementsTest, Int64) {
    141   int64_t count;
    142   tensorflow::Status status = tensorflow::Status::OK();
    143 
    144   status = NumElements(std::vector<int64_t>{16777216, 16777216, 32767}, &count);
    145   EXPECT_TRUE(status.ok());
    146   EXPECT_EQ(count, 9223090561878065152LL);
    147 
    148   status = NumElements(std::vector<int64_t>{1, 2, -3}, &count);
    149   EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
    150 
    151   status = NumElements(std::vector<int64_t>{16777216, 16777216, 32768}, &count);
    152   EXPECT_EQ(status.error_message(), kLargeTensorMessage);
    153 }
    154 
    155 TEST(NumElementsTest, UnsignedInt32) {
    156   uint32_t count;
    157   tensorflow::Status status = tensorflow::Status::OK();
    158 
    159   status = NumElements(std::vector<uint32_t>{1024, 2048, 2047}, &count);
    160   EXPECT_TRUE(status.ok());
    161   EXPECT_EQ(count, 4292870144);
    162 
    163   status = NumElements(std::vector<int>{1, 2, -3}, &count);
    164   EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
    165 
    166   status = NumElements(std::vector<uint32_t>{1024, 2048, 2048}, &count);
    167   EXPECT_EQ(status.error_message(), kLargeTensorMessage);
    168 }
    169 
    170 TEST(NumElementsTest, UnsignedInt64) {
    171   uint64_t count;
    172   tensorflow::Status status = tensorflow::Status::OK();
    173 
    174   status =
    175       NumElements(std::vector<uint64_t>{16777216, 16777216, 65535}, &count);
    176   EXPECT_TRUE(status.ok());
    177   EXPECT_EQ(count, 18446462598732840960ULL);
    178 
    179   status = NumElements(std::vector<int>{1, 2, -3}, &count);
    180   EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
    181 
    182   status =
    183       NumElements(std::vector<uint64_t>{16777216, 16777216, 65536}, &count);
    184   EXPECT_EQ(status.error_message(), kLargeTensorMessage);
    185 }
    186 
    187 TEST(NumElementsTest, Scalar) {
    188   tensorflow::Status status = tensorflow::Status::OK();
    189 
    190   int32_t count;
    191   status = NumElements(std::vector<int32_t>{}, &count);
    192   EXPECT_TRUE(status.ok());
    193   EXPECT_EQ(count, 1);
    194 
    195   uint64_t countu64;
    196   status = NumElements(std::vector<uint64_t>{}, &countu64);
    197   EXPECT_TRUE(status.ok());
    198   EXPECT_EQ(countu64, 1ULL);
    199 }
    200 
    201 TEST(FusedActivationTest, DefaultsToUnfused) {
    202   EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd));
    203   EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone));
    204   EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast<OperatorType>(255)));
    205 }
    206 
    207 }  // namespace toco
    208 
    209 int main(int argc, char** argv) {
    210   ::tflite::LogToStderr();
    211   ::testing::InitGoogleTest(&argc, argv);
    212   ::toco::port::InitGoogleWasDoneElsewhere();
    213   return RUN_ALL_TESTS();
    214 }
    215