1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "utils/tensor-view.h" 18 19 #include "gmock/gmock.h" 20 #include "gtest/gtest.h" 21 22 namespace libtextclassifier3 { 23 namespace { 24 25 TEST(TensorViewTest, TestSize) { 26 std::vector<float> data{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; 27 const TensorView<float> tensor(data.data(), {3, 1, 2}); 28 EXPECT_TRUE(tensor.is_valid()); 29 EXPECT_EQ(tensor.shape(), (std::vector<int>{3, 1, 2})); 30 EXPECT_EQ(tensor.data(), data.data()); 31 EXPECT_EQ(tensor.size(), 6); 32 EXPECT_EQ(tensor.dims(), 3); 33 EXPECT_EQ(tensor.dim(0), 3); 34 EXPECT_EQ(tensor.dim(1), 1); 35 EXPECT_EQ(tensor.dim(2), 2); 36 std::vector<float> output_data(6); 37 EXPECT_TRUE(tensor.copy_to(output_data.data(), output_data.size())); 38 EXPECT_EQ(data, output_data); 39 40 // Should not copy when the output is small. 41 std::vector<float> small_output_data{-1, -1, -1}; 42 EXPECT_FALSE( 43 tensor.copy_to(small_output_data.data(), small_output_data.size())); 44 // The output buffer should not be changed. 45 EXPECT_EQ(small_output_data, (std::vector<float>{-1, -1, -1})); 46 47 const TensorView<float> invalid_tensor = TensorView<float>::Invalid(); 48 EXPECT_FALSE(invalid_tensor.is_valid()); 49 } 50 51 } // namespace 52 } // namespace libtextclassifier3 53