1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <tuple> 16 #include <vector> 17 18 #include <gtest/gtest.h> 19 #include "tensorflow/core/lib/core/status.h" 20 #include "tensorflow/lite/testing/util.h" 21 #include "tensorflow/lite/toco/model.h" 22 #include "tensorflow/lite/toco/toco_port.h" 23 #include "tensorflow/lite/toco/tooling_util.h" 24 25 namespace toco { 26 27 enum class Agreement { kBroadcast, kExtend, kBroadcastNotExtend, kNeither }; 28 29 // A pair of Shapes and whether they should agree up to broadcasting, extending 30 // or neither. 31 struct ShapePair { 32 Shape left; 33 Shape right; 34 Agreement agreement; 35 }; 36 37 std::vector<ShapePair> CreateShapePairs() { 38 return std::vector<ShapePair>( 39 {// These agree up to broadcast. 40 {Shape({3}), Shape({3}), Agreement::kBroadcast}, 41 {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast}, 42 {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast}, 43 {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast}, 44 {Shape({}), Shape({3}), Agreement::kBroadcast}, 45 {Shape({}), Shape({3, 1}), Agreement::kBroadcast}, 46 47 // These extend (and therefore broadcast). 48 {Shape({3}), Shape({3}), Agreement::kExtend}, 49 {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kExtend}, 50 {Shape({1, 1, 3}), Shape({1, 1, 3}), Agreement::kExtend}, 51 {Shape({1, 1, 3}), Shape({3}), Agreement::kExtend}, 52 {Shape({1, 1, 3}), Shape({1, 3}), Agreement::kExtend}, 53 54 // These strictly broadcast and do not extend. 55 {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcastNotExtend}, 56 {Shape({5, 4}), Shape({1}), Agreement::kBroadcastNotExtend}, 57 {Shape({5, 4}), Shape({4}), Agreement::kBroadcastNotExtend}, 58 {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend}, 59 {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend}, 60 {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend}, 61 {Shape({3, 1}), Shape({}), Agreement::kBroadcastNotExtend}, 62 63 // These do not broadcast (and therefore also do not extend). 64 {Shape({3}), Shape({4}), Agreement::kNeither}, 65 {Shape({2, 1}), Shape({8, 4, 3}), Agreement::kNeither}}); 66 } 67 68 // ShapeTest is an empty parameterized test fixture since there is no state. 69 class ShapeTest : public ::testing::TestWithParam<ShapePair> {}; 70 71 TEST_P(ShapeTest, Agrees) { 72 const ShapePair& param = GetParam(); 73 74 switch (param.agreement) { 75 case Agreement::kBroadcast: { 76 EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); 77 break; 78 } 79 case Agreement::kExtend: { 80 EXPECT_TRUE(ShapesAgreeUpToExtending(param.left, param.right)); 81 // Anything that extends should also broadcast. 82 EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); 83 break; 84 } 85 case Agreement::kBroadcastNotExtend: { 86 // Verify that it strictly broadcasts but does not extend. 87 EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); 88 EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right)); 89 break; 90 } 91 case Agreement::kNeither: { 92 EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right)); 93 EXPECT_FALSE(ShapesAgreeUpToBroadcasting(param.left, param.right)); 94 break; 95 } 96 } 97 } 98 99 INSTANTIATE_TEST_SUITE_P(AgreeBroadcast, ShapeTest, 100 ::testing::ValuesIn(CreateShapePairs())); 101 102 static const char kNegativeValuesMessage[] = 103 "Tensor shape should not include negative values"; 104 static const char kLargeTensorMessage[] = "Tensor shape is too large"; 105 106 TEST(NumElementsTest, Int) { 107 int count; 108 tensorflow::Status status = tensorflow::Status::OK(); 109 110 status = NumElements(std::vector<int>{1024, 1024, 2047}, &count); 111 EXPECT_TRUE(status.ok()); 112 EXPECT_EQ(count, 2146435072); 113 114 status = NumElements(std::vector<int>{1024, 0, 2048}, &count); 115 EXPECT_TRUE(status.ok()); 116 EXPECT_EQ(count, 0); 117 118 status = NumElements(std::vector<int>{1, 2, -3}, &count); 119 EXPECT_EQ(status.error_message(), kNegativeValuesMessage); 120 121 status = NumElements(std::vector<int>{1024, 1024, 2048}, &count); 122 EXPECT_EQ(status.error_message(), kLargeTensorMessage); 123 } 124 125 TEST(NumElementsTest, Int32) { 126 int32_t count; 127 tensorflow::Status status = tensorflow::Status::OK(); 128 129 status = NumElements(std::vector<int32_t>{1024, 1024, 2047}, &count); 130 EXPECT_TRUE(status.ok()); 131 EXPECT_EQ(count, 2146435072); 132 133 status = NumElements(std::vector<int32_t>{1, 2, -3}, &count); 134 EXPECT_EQ(status.error_message(), kNegativeValuesMessage); 135 136 status = NumElements(std::vector<int32_t>{1024, 1024, 2048}, &count); 137 EXPECT_EQ(status.error_message(), kLargeTensorMessage); 138 } 139 140 TEST(NumElementsTest, Int64) { 141 int64_t count; 142 tensorflow::Status status = tensorflow::Status::OK(); 143 144 status = NumElements(std::vector<int64_t>{16777216, 16777216, 32767}, &count); 145 EXPECT_TRUE(status.ok()); 146 EXPECT_EQ(count, 9223090561878065152LL); 147 148 status = NumElements(std::vector<int64_t>{1, 2, -3}, &count); 149 EXPECT_EQ(status.error_message(), kNegativeValuesMessage); 150 151 status = NumElements(std::vector<int64_t>{16777216, 16777216, 32768}, &count); 152 EXPECT_EQ(status.error_message(), kLargeTensorMessage); 153 } 154 155 TEST(NumElementsTest, UnsignedInt32) { 156 uint32_t count; 157 tensorflow::Status status = tensorflow::Status::OK(); 158 159 status = NumElements(std::vector<uint32_t>{1024, 2048, 2047}, &count); 160 EXPECT_TRUE(status.ok()); 161 EXPECT_EQ(count, 4292870144); 162 163 status = NumElements(std::vector<int>{1, 2, -3}, &count); 164 EXPECT_EQ(status.error_message(), kNegativeValuesMessage); 165 166 status = NumElements(std::vector<uint32_t>{1024, 2048, 2048}, &count); 167 EXPECT_EQ(status.error_message(), kLargeTensorMessage); 168 } 169 170 TEST(NumElementsTest, UnsignedInt64) { 171 uint64_t count; 172 tensorflow::Status status = tensorflow::Status::OK(); 173 174 status = 175 NumElements(std::vector<uint64_t>{16777216, 16777216, 65535}, &count); 176 EXPECT_TRUE(status.ok()); 177 EXPECT_EQ(count, 18446462598732840960ULL); 178 179 status = NumElements(std::vector<int>{1, 2, -3}, &count); 180 EXPECT_EQ(status.error_message(), kNegativeValuesMessage); 181 182 status = 183 NumElements(std::vector<uint64_t>{16777216, 16777216, 65536}, &count); 184 EXPECT_EQ(status.error_message(), kLargeTensorMessage); 185 } 186 187 TEST(NumElementsTest, Scalar) { 188 tensorflow::Status status = tensorflow::Status::OK(); 189 190 int32_t count; 191 status = NumElements(std::vector<int32_t>{}, &count); 192 EXPECT_TRUE(status.ok()); 193 EXPECT_EQ(count, 1); 194 195 uint64_t countu64; 196 status = NumElements(std::vector<uint64_t>{}, &countu64); 197 EXPECT_TRUE(status.ok()); 198 EXPECT_EQ(countu64, 1ULL); 199 } 200 201 TEST(FusedActivationTest, DefaultsToUnfused) { 202 EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd)); 203 EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone)); 204 EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast<OperatorType>(255))); 205 } 206 207 } // namespace toco 208 209 int main(int argc, char** argv) { 210 ::tflite::LogToStderr(); 211 ::testing::InitGoogleTest(&argc, argv); 212 ::toco::port::InitGoogleWasDoneElsewhere(); 213 return RUN_ALL_TESTS(); 214 } 215