Home | History | Annotate | Download | only in test
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "TestMemory.h"
     18 
     19 #include "NeuralNetworksWrapper.h"
     20 
     21 #include <gtest/gtest.h>
     22 #include <sys/mman.h>
     23 #include <sys/types.h>
     24 #include <unistd.h>
     25 
     26 using WrapperCompilation = ::android::nn::wrapper::Compilation;
     27 using WrapperExecution = ::android::nn::wrapper::Execution;
     28 using WrapperMemory = ::android::nn::wrapper::Memory;
     29 using WrapperModel = ::android::nn::wrapper::Model;
     30 using WrapperOperandType = ::android::nn::wrapper::OperandType;
     31 using WrapperResult = ::android::nn::wrapper::Result;
     32 using WrapperType = ::android::nn::wrapper::Type;
     33 
     34 namespace {
     35 
     36 // Tests the various ways to pass weights and input/output data.
     37 class MemoryTest : public ::testing::Test {
     38 protected:
     39     void SetUp() override {}
     40 
     41 };
     42 
     43 TEST_F(MemoryTest, TestFd) {
     44     // Create a file that contains matrix2 and matrix3.
     45     char path[] = "/data/local/tmp/TestMemoryXXXXXX";
     46     int fd = mkstemp(path);
     47     const uint32_t offsetForMatrix2 = 20;
     48     const uint32_t offsetForMatrix3 = 200;
     49     static_assert(offsetForMatrix2 + sizeof(matrix2) < offsetForMatrix3, "matrices overlap");
     50     lseek(fd, offsetForMatrix2, SEEK_SET);
     51     write(fd, matrix2, sizeof(matrix2));
     52     lseek(fd, offsetForMatrix3, SEEK_SET);
     53     write(fd, matrix3, sizeof(matrix3));
     54     fsync(fd);
     55 
     56     WrapperMemory weights(offsetForMatrix3 + sizeof(matrix3), PROT_READ, fd, 0);
     57     ASSERT_TRUE(weights.isValid());
     58 
     59     WrapperModel model;
     60     WrapperOperandType matrixType(WrapperType::TENSOR_FLOAT32, {3, 4});
     61     WrapperOperandType scalarType(WrapperType::INT32, {});
     62     int32_t activation(0);
     63     auto a = model.addOperand(&matrixType);
     64     auto b = model.addOperand(&matrixType);
     65     auto c = model.addOperand(&matrixType);
     66     auto d = model.addOperand(&matrixType);
     67     auto e = model.addOperand(&matrixType);
     68     auto f = model.addOperand(&scalarType);
     69 
     70     model.setOperandValueFromMemory(e, &weights, offsetForMatrix2, sizeof(Matrix3x4));
     71     model.setOperandValueFromMemory(a, &weights, offsetForMatrix3, sizeof(Matrix3x4));
     72     model.setOperandValue(f, &activation, sizeof(activation));
     73     model.addOperation(ANEURALNETWORKS_ADD, {a, c, f}, {b});
     74     model.addOperation(ANEURALNETWORKS_ADD, {b, e, f}, {d});
     75     model.identifyInputsAndOutputs({c}, {d});
     76     ASSERT_TRUE(model.isValid());
     77     model.finish();
     78 
     79     // Test the three node model.
     80     Matrix3x4 actual;
     81     memset(&actual, 0, sizeof(actual));
     82     WrapperCompilation compilation2(&model);
     83     ASSERT_EQ(compilation2.finish(), WrapperResult::NO_ERROR);
     84     WrapperExecution execution2(&compilation2);
     85     ASSERT_EQ(execution2.setInput(0, matrix1, sizeof(Matrix3x4)), WrapperResult::NO_ERROR);
     86     ASSERT_EQ(execution2.setOutput(0, actual, sizeof(Matrix3x4)), WrapperResult::NO_ERROR);
     87     ASSERT_EQ(execution2.compute(), WrapperResult::NO_ERROR);
     88     ASSERT_EQ(CompareMatrices(expected3, actual), 0);
     89 
     90     close(fd);
     91     unlink(path);
     92 }
     93 
     94 }  // end namespace
     95