1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "OperationsUtils.cpp" 17 18 #include "gmock/gmock-matchers.h" 19 #include "gtest/gtest.h" 20 21 namespace android { 22 namespace nn { 23 namespace wrapper { 24 25 namespace { 26 using ::testing::ElementsAreArray; 27 } // namespace 28 29 TEST(CalculateBroadcastedShapeTest, Basic) { 30 Shape shape1; 31 Shape shape2; 32 shape1.dimensions = {4, 3, 2, 1}; 33 shape2.dimensions = {3, 1, 5}; 34 35 Shape expectedOutputShape; 36 expectedOutputShape.dimensions = {4, 3, 2, 5}; 37 38 Shape actualOutputShape; 39 EXPECT_TRUE(calculateBroadcastedShape(shape1, shape2, &actualOutputShape)); 40 EXPECT_THAT(actualOutputShape.dimensions, ElementsAreArray(expectedOutputShape.dimensions)); 41 42 EXPECT_TRUE(calculateBroadcastedShape(shape2, shape1, &actualOutputShape)); 43 EXPECT_THAT(actualOutputShape.dimensions, ElementsAreArray(expectedOutputShape.dimensions)); 44 } 45 46 TEST(CalculateBroadcastedShapeTest, FailsOnIncompatible) { 47 Shape shape1; 48 Shape shape2; 49 shape1.dimensions = {5}; 50 shape2.dimensions = {3}; 51 52 Shape actualOutputShape; 53 EXPECT_FALSE(calculateBroadcastedShape(shape1, shape2, &actualOutputShape)); 54 EXPECT_FALSE(calculateBroadcastedShape(shape2, shape1, &actualOutputShape)); 55 } 56 57 static int32_t getExtensionType(uint16_t extensionPrefix, uint16_t typeWithinExtension) { 58 constexpr uint8_t kLowBitsType = 59 static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE); 60 int32_t type = (extensionPrefix << kLowBitsType) | typeWithinExtension; 61 EXPECT_TRUE(isExtensionOperandType(static_cast<OperandType>(type))); 62 return type; 63 } 64 65 TEST(TensorHasUnspecifiedDimensionsTest, ExtensionTensorWithUnspecifiedRank) { 66 // Regression test for b/124285861. 67 EXPECT_TRUE(tensorHasUnspecifiedDimensions(getExtensionType(1, 0), /*dim=*/nullptr, 68 /*dimCount=*/0)); 69 } 70 71 TEST(ValidateOperandTypeTest, ExtensionTensorWithUnspecifiedRank) { 72 // Regression test for b/124104123. 73 constexpr uint16_t kExtensionPrefix = 1; 74 constexpr uint16_t kTypeWithinExtension = 0; 75 int32_t extensionType = getExtensionType(kExtensionPrefix, kTypeWithinExtension); 76 ANeuralNetworksOperandType type = { 77 .type = extensionType, 78 .dimensionCount = 0, 79 .dimensions = nullptr, 80 }; 81 Extension::OperandTypeInformation info = { 82 .type = kTypeWithinExtension, 83 .isTensor = true, 84 .byteSize = 4, 85 }; 86 EXPECT_EQ(validateOperandType(type, &info, /*tag=*/"test", /*allowPartial=*/true), 87 ANEURALNETWORKS_NO_ERROR); 88 EXPECT_EQ(validateOperandType(type, &info, /*tag=*/"test", /*allowPartial=*/false), 89 ANEURALNETWORKS_BAD_DATA); 90 } 91 92 } // namespace wrapper 93 } // namespace nn 94 } // namespace android 95