Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 #include "OperationsUtils.cpp"
     17 
     18 #include "gmock/gmock-matchers.h"
     19 #include "gtest/gtest.h"
     20 
     21 namespace android {
     22 namespace nn {
     23 namespace wrapper {
     24 
     25 namespace {
     26 using ::testing::ElementsAreArray;
     27 }  // namespace
     28 
     29 TEST(CalculateBroadcastedShapeTest, Basic) {
     30     Shape shape1;
     31     Shape shape2;
     32     shape1.dimensions = {4, 3, 2, 1};
     33     shape2.dimensions = {3, 1, 5};
     34 
     35     Shape expectedOutputShape;
     36     expectedOutputShape.dimensions = {4, 3, 2, 5};
     37 
     38     Shape actualOutputShape;
     39     EXPECT_TRUE(calculateBroadcastedShape(shape1, shape2, &actualOutputShape));
     40     EXPECT_THAT(actualOutputShape.dimensions, ElementsAreArray(expectedOutputShape.dimensions));
     41 
     42     EXPECT_TRUE(calculateBroadcastedShape(shape2, shape1, &actualOutputShape));
     43     EXPECT_THAT(actualOutputShape.dimensions, ElementsAreArray(expectedOutputShape.dimensions));
     44 }
     45 
     46 TEST(CalculateBroadcastedShapeTest, FailsOnIncompatible) {
     47     Shape shape1;
     48     Shape shape2;
     49     shape1.dimensions = {5};
     50     shape2.dimensions = {3};
     51 
     52     Shape actualOutputShape;
     53     EXPECT_FALSE(calculateBroadcastedShape(shape1, shape2, &actualOutputShape));
     54     EXPECT_FALSE(calculateBroadcastedShape(shape2, shape1, &actualOutputShape));
     55 }
     56 
     57 static int32_t getExtensionType(uint16_t extensionPrefix, uint16_t typeWithinExtension) {
     58     constexpr uint8_t kLowBitsType =
     59             static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
     60     int32_t type = (extensionPrefix << kLowBitsType) | typeWithinExtension;
     61     EXPECT_TRUE(isExtensionOperandType(static_cast<OperandType>(type)));
     62     return type;
     63 }
     64 
     65 TEST(TensorHasUnspecifiedDimensionsTest, ExtensionTensorWithUnspecifiedRank) {
     66     // Regression test for b/124285861.
     67     EXPECT_TRUE(tensorHasUnspecifiedDimensions(getExtensionType(1, 0), /*dim=*/nullptr,
     68                                                /*dimCount=*/0));
     69 }
     70 
     71 TEST(ValidateOperandTypeTest, ExtensionTensorWithUnspecifiedRank) {
     72     // Regression test for b/124104123.
     73     constexpr uint16_t kExtensionPrefix = 1;
     74     constexpr uint16_t kTypeWithinExtension = 0;
     75     int32_t extensionType = getExtensionType(kExtensionPrefix, kTypeWithinExtension);
     76     ANeuralNetworksOperandType type = {
     77             .type = extensionType,
     78             .dimensionCount = 0,
     79             .dimensions = nullptr,
     80     };
     81     Extension::OperandTypeInformation info = {
     82             .type = kTypeWithinExtension,
     83             .isTensor = true,
     84             .byteSize = 4,
     85     };
     86     EXPECT_EQ(validateOperandType(type, &info, /*tag=*/"test", /*allowPartial=*/true),
     87               ANEURALNETWORKS_NO_ERROR);
     88     EXPECT_EQ(validateOperandType(type, &info, /*tag=*/"test", /*allowPartial=*/false),
     89               ANEURALNETWORKS_BAD_DATA);
     90 }
     91 
     92 }  // namespace wrapper
     93 }  // namespace nn
     94 }  // namespace android
     95