1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "Operations.h" 20 #include "OperationsUtils.h" 21 22 #include "Tracing.h" 23 24 namespace android { 25 namespace nn { 26 27 template <typename Scalar> 28 bool splitGeneric(const Scalar* inputData, const Shape& inputShape, int32_t axis, 29 const std::vector<Scalar*>* outputDataPtrs, 30 const std::vector<Shape>& outputShapes) { 31 NN_CHECK(handleNegativeAxis(inputShape, &axis)); 32 int outerSize = 1; 33 for (int i = 0; i < axis; ++i) { 34 outerSize *= inputShape.dimensions[i]; 35 } 36 int baseInnerSize = 1; 37 int concatDimensions = getNumberOfDimensions(inputShape); 38 for (int i = axis + 1; i < concatDimensions; ++i) { 39 baseInnerSize *= inputShape.dimensions[i]; 40 } 41 42 const Scalar* inputPtr = inputData; 43 for (int k = 0; k < outerSize; k++) { 44 for (int i = 0; i < outputDataPtrs->size(); ++i) { 45 const int copySize = outputShapes[i].dimensions[axis] * baseInnerSize; 46 memcpy(outputDataPtrs->at(i) + k * copySize, inputPtr, copySize * sizeof(Scalar)); 47 inputPtr += copySize; 48 } 49 } 50 51 return true; 52 } 53 54 bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis, 55 const std::vector<_Float16*>* outputDataPtrs, 56 const std::vector<Shape>& outputShapes) { 57 NNTRACE_COMP("splitFloat16"); 58 return splitGeneric<_Float16>(inputData, inputShape, axis, outputDataPtrs, outputShapes); 59 } 60 61 bool splitFloat32(const float* inputData, const Shape& inputShape, int32_t axis, 62 const std::vector<float*>* outputDataPtrs, 63 const std::vector<Shape>& outputShapes) { 64 NNTRACE_COMP("splitFloat32"); 65 return splitGeneric<float>(inputData, inputShape, axis, outputDataPtrs, outputShapes); 66 } 67 68 bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, int32_t axis, 69 const std::vector<uint8_t*>* outputDataPtrs, 70 const std::vector<Shape>& outputShapes) { 71 NNTRACE_COMP("splitQuant8"); 72 return splitGeneric<uint8_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes); 73 } 74 75 bool splitInt32(const int32_t* inputData, const Shape& inputShape, int32_t axis, 76 const std::vector<int32_t*>* outputDataPtrs, 77 const std::vector<Shape>& outputShapes) { 78 NNTRACE_COMP("splitInt32"); 79 return splitGeneric<int32_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes); 80 } 81 82 } // namespace nn 83 } // namespace android 84