1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "LSHProjection.h" 18 19 #include "CpuExecutor.h" 20 #include "Tracing.h" 21 #include "Utils.h" 22 23 #include "utils/hash/farmhash.h" 24 25 namespace android { 26 namespace nn { 27 28 LSHProjection::LSHProjection(const Operation& operation, 29 std::vector<RunTimeOperandInfo>& operands) { 30 input_ = GetInput(operation, operands, kInputTensor); 31 weight_ = GetInput(operation, operands, kWeightTensor); 32 hash_ = GetInput(operation, operands, kHashTensor); 33 34 type_ = static_cast<LSHProjectionType>( 35 getScalarData<int32_t>(*GetInput(operation, operands, kTypeParam))); 36 37 output_ = GetOutput(operation, operands, kOutputTensor); 38 } 39 40 bool LSHProjection::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands, 41 Shape* outputShape) { 42 const int num_inputs = NumInputsWithValues(operation, operands); 43 NN_CHECK(num_inputs == 3 || num_inputs == 4); 44 NN_CHECK_EQ(NumOutputs(operation), 1); 45 46 const RunTimeOperandInfo* hash = GetInput(operation, operands, kHashTensor); 47 NN_CHECK_EQ(NumDimensions(hash), 2); 48 // Support up to 32 bits. 49 NN_CHECK(SizeOfDimension(hash, 1) <= 32); 50 51 const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor); 52 NN_CHECK(NumDimensions(input) >= 1); 53 54 auto type = static_cast<LSHProjectionType>( 55 getScalarData<int32_t>(operands[operation.inputs[kTypeParam]])); 56 switch (type) { 57 case LSHProjectionType_SPARSE: 58 case LSHProjectionType_SPARSE_DEPRECATED: 59 NN_CHECK(NumInputsWithValues(operation, operands) == 3); 60 outputShape->dimensions = {SizeOfDimension(hash, 0)}; 61 break; 62 case LSHProjectionType_DENSE: { 63 RunTimeOperandInfo* weight = GetInput(operation, operands, kWeightTensor); 64 NN_CHECK_EQ(NumInputsWithValues(operation, operands), 4); 65 NN_CHECK_EQ(NumDimensions(weight), 1); 66 NN_CHECK_EQ(SizeOfDimension(weight, 0), SizeOfDimension(input, 0)); 67 outputShape->dimensions = {SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1)}; 68 break; 69 } 70 default: 71 return false; 72 } 73 74 outputShape->type = OperandType::TENSOR_INT32; 75 outputShape->offset = 0; 76 outputShape->scale = 0.f; 77 78 return true; 79 } 80 81 // Compute sign bit of dot product of hash(seed, input) and weight. 82 // NOTE: use float as seed, and convert it to double as a temporary solution 83 // to match the trained model. This is going to be changed once the new 84 // model is trained in an optimized method. 85 // 86 template <typename T> 87 int runningSignBit(const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight, float seed) { 88 double score = 0.0; 89 int input_item_bytes = nonExtensionOperandSizeOfData(input->type, input->dimensions) / 90 SizeOfDimension(input, 0); 91 char* input_ptr = (char*)(input->buffer); 92 93 const size_t seed_size = sizeof(seed); 94 const size_t key_bytes = seed_size + input_item_bytes; 95 std::unique_ptr<char[]> key(new char[key_bytes]); 96 97 for (uint32_t i = 0; i < SizeOfDimension(input, 0); ++i) { 98 // Create running hash id and value for current dimension. 99 memcpy(key.get(), &seed, seed_size); 100 memcpy(key.get() + seed_size, input_ptr, input_item_bytes); 101 102 int64_t hash_signature = farmhash::Fingerprint64(key.get(), key_bytes); 103 double running_value = static_cast<double>(hash_signature); 104 input_ptr += input_item_bytes; 105 if (weight->lifetime == OperandLifeTime::NO_VALUE) { 106 score += running_value; 107 } else { 108 score += static_cast<double>(reinterpret_cast<T*>(weight->buffer)[i]) * running_value; 109 } 110 } 111 112 return (score > 0) ? 1 : 0; 113 } 114 115 template <typename T> 116 void SparseLshProjection(LSHProjectionType type, const RunTimeOperandInfo* hash, 117 const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight, 118 int32_t* out_buf) { 119 int num_hash = SizeOfDimension(hash, 0); 120 int num_bits = SizeOfDimension(hash, 1); 121 for (int i = 0; i < num_hash; i++) { 122 int32_t hash_signature = 0; 123 for (int j = 0; j < num_bits; j++) { 124 T seed = reinterpret_cast<T*>(hash->buffer)[i * num_bits + j]; 125 int bit = runningSignBit<T>(input, weight, static_cast<float>(seed)); 126 hash_signature = (hash_signature << 1) | bit; 127 } 128 if (type == LSHProjectionType_SPARSE_DEPRECATED) { 129 *out_buf++ = hash_signature; 130 } else { 131 *out_buf++ = hash_signature + i * (1 << num_bits); 132 } 133 } 134 } 135 136 template <typename T> 137 void DenseLshProjection(const RunTimeOperandInfo* hash, const RunTimeOperandInfo* input, 138 const RunTimeOperandInfo* weight, int32_t* out_buf) { 139 int num_hash = SizeOfDimension(hash, 0); 140 int num_bits = SizeOfDimension(hash, 1); 141 for (int i = 0; i < num_hash; i++) { 142 for (int j = 0; j < num_bits; j++) { 143 T seed = reinterpret_cast<T*>(hash->buffer)[i * num_bits + j]; 144 int bit = runningSignBit<T>(input, weight, static_cast<float>(seed)); 145 *out_buf++ = bit; 146 } 147 } 148 } 149 150 template <typename T> 151 bool LSHProjection::Eval() { 152 NNTRACE_COMP("LSHProjection::Eval"); 153 154 int32_t* out_buf = reinterpret_cast<int32_t*>(output_->buffer); 155 156 switch (type_) { 157 case LSHProjectionType_DENSE: 158 DenseLshProjection<T>(hash_, input_, weight_, out_buf); 159 break; 160 case LSHProjectionType_SPARSE: 161 case LSHProjectionType_SPARSE_DEPRECATED: 162 SparseLshProjection<T>(type_, hash_, input_, weight_, out_buf); 163 break; 164 default: 165 return false; 166 } 167 return true; 168 } 169 170 template bool LSHProjection::Eval<float>(); 171 template bool LSHProjection::Eval<_Float16>(); 172 173 template int runningSignBit<float>(const RunTimeOperandInfo* input, 174 const RunTimeOperandInfo* weight, float seed); 175 template int runningSignBit<_Float16>(const RunTimeOperandInfo* input, 176 const RunTimeOperandInfo* weight, float seed); 177 178 template void SparseLshProjection<float>(LSHProjectionType type, const RunTimeOperandInfo* hash, 179 const RunTimeOperandInfo* input, 180 const RunTimeOperandInfo* weight, int32_t* outBuffer); 181 template void SparseLshProjection<_Float16>(LSHProjectionType type, const RunTimeOperandInfo* hash, 182 const RunTimeOperandInfo* input, 183 const RunTimeOperandInfo* weight, int32_t* outBuffer); 184 185 template void DenseLshProjection<float>(const RunTimeOperandInfo* hash, 186 const RunTimeOperandInfo* input, 187 const RunTimeOperandInfo* weight, int32_t* outBuffer); 188 template void DenseLshProjection<_Float16>(const RunTimeOperandInfo* hash, 189 const RunTimeOperandInfo* input, 190 const RunTimeOperandInfo* weight, int32_t* outBuffer); 191 192 } // namespace nn 193 } // namespace android 194