1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "HashtableLookup.h" 18 19 #include "CpuExecutor.h" 20 #include "Operations.h" 21 22 #include "Tracing.h" 23 24 namespace android { 25 namespace nn { 26 27 namespace { 28 29 int greater(const void* a, const void* b) { 30 return *static_cast<const int*>(a) - *static_cast<const int*>(b); 31 } 32 33 } // anonymous namespace 34 35 HashtableLookup::HashtableLookup(const Operation& operation, 36 std::vector<RunTimeOperandInfo>& operands) { 37 lookup_ = GetInput(operation, operands, kLookupTensor); 38 key_ = GetInput(operation, operands, kKeyTensor); 39 value_ = GetInput(operation, operands, kValueTensor); 40 41 output_ = GetOutput(operation, operands, kOutputTensor); 42 hits_ = GetOutput(operation, operands, kHitsTensor); 43 } 44 45 bool HashtableLookup::Eval() { 46 NNTRACE_COMP("HashtableLookup::Eval"); 47 const int num_rows = value_->shape().dimensions[0]; 48 const int row_bytes = nonExtensionOperandSizeOfData(value_->type, value_->dimensions) / num_rows; 49 void* pointer = nullptr; 50 51 for (int i = 0; i < static_cast<int>(lookup_->shape().dimensions[0]); i++) { 52 int idx = -1; 53 pointer = bsearch(lookup_->buffer + sizeof(int) * i, key_->buffer, 54 num_rows, sizeof(int), greater); 55 if (pointer != nullptr) { 56 idx = 57 (reinterpret_cast<uint8_t*>(pointer) - key_->buffer) / sizeof(float); 58 } 59 60 if (idx >= num_rows || idx < 0) { 61 memset(output_->buffer + i * row_bytes, 0, row_bytes); 62 hits_->buffer[i] = 0; 63 } else { 64 memcpy(output_->buffer + i * row_bytes, value_->buffer + idx * row_bytes, 65 row_bytes); 66 hits_->buffer[i] = 1; 67 } 68 } 69 70 return true; 71 } 72 73 } // namespace nn 74 } // namespace android 75