Home | History | Annotate | Download | only in nnCache
      1 /*
      2  ** Copyright 2011, The Android Open Source Project
      3  **
      4  ** Licensed under the Apache License, Version 2.0 (the "License");
      5  ** you may not use this file except in compliance with the License.
      6  ** You may obtain a copy of the License at
      7  **
      8  **     http://www.apache.org/licenses/LICENSE-2.0
      9  **
     10  ** Unless required by applicable law or agreed to in writing, software
     11  ** distributed under the License is distributed on an "AS IS" BASIS,
     12  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  ** See the License for the specific language governing permissions and
     14  ** limitations under the License.
     15  */
     16 
#include "nnCache.h"

#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <memory>
#include <new>
#include <thread>

#include <log/log.h>
     27 
     28 // Cache file header
     29 static const char* cacheFileMagic = "nn$$";
     30 static const size_t cacheFileHeaderSize = 8;
     31 
     32 // The time in seconds to wait before saving newly inserted cache entries.
     33 static const unsigned int deferredSaveDelay = 4;
     34 
     35 // ----------------------------------------------------------------------------
     36 namespace android {
     37 // ----------------------------------------------------------------------------
     38 
     39 //
     40 // NNCache definition
     41 //
     42 NNCache::NNCache() :
     43     mInitialized(false),
     44     mMaxKeySize(0), mMaxValueSize(0), mMaxTotalSize(0),
     45     mPolicy(defaultPolicy()),
     46     mSavePending(false) {
     47 }
     48 
     49 NNCache::~NNCache() {
     50 }
     51 
     52 NNCache NNCache::sCache;
     53 
     54 NNCache* NNCache::get() {
     55     return &sCache;
     56 }
     57 
     58 void NNCache::initialize(size_t maxKeySize, size_t maxValueSize, size_t maxTotalSize,
     59                          Policy policy) {
     60     std::lock_guard<std::mutex> lock(mMutex);
     61     mInitialized = true;
     62     mMaxKeySize = maxKeySize;
     63     mMaxValueSize = maxValueSize;
     64     mMaxTotalSize = maxTotalSize;
     65     mPolicy = policy;
     66 }
     67 
     68 void NNCache::terminate() {
     69     std::lock_guard<std::mutex> lock(mMutex);
     70     saveBlobCacheLocked();
     71     mBlobCache = NULL;
     72     mInitialized = false;
     73 }
     74 
     75 void NNCache::setBlob(const void* key, ssize_t keySize,
     76         const void* value, ssize_t valueSize) {
     77     std::lock_guard<std::mutex> lock(mMutex);
     78 
     79     if (keySize < 0 || valueSize < 0) {
     80         ALOGW("nnCache::setBlob: negative sizes are not allowed");
     81         return;
     82     }
     83 
     84     if (mInitialized) {
     85         BlobCache* bc = getBlobCacheLocked();
     86         bc->set(key, keySize, value, valueSize);
     87 
     88         if (!mSavePending) {
     89             mSavePending = true;
     90             std::thread deferredSaveThread([this]() {
     91                 sleep(deferredSaveDelay);
     92                 std::lock_guard<std::mutex> lock(mMutex);
     93                 if (mInitialized) {
     94                     saveBlobCacheLocked();
     95                 }
     96                 mSavePending = false;
     97             });
     98             deferredSaveThread.detach();
     99         }
    100     }
    101 }
    102 
    103 ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
    104         void* value, ssize_t valueSize) {
    105     std::lock_guard<std::mutex> lock(mMutex);
    106 
    107     if (keySize < 0 || valueSize < 0) {
    108         ALOGW("nnCache::getBlob: negative sizes are not allowed");
    109         return 0;
    110     }
    111 
    112     if (mInitialized) {
    113         BlobCache* bc = getBlobCacheLocked();
    114         return bc->get(key, keySize, value, valueSize);
    115     }
    116     return 0;
    117 }
    118 
    119 ssize_t NNCache::getBlob(const void* key, ssize_t keySize,
    120         void** value, std::function<void*(size_t)> alloc) {
    121     std::lock_guard<std::mutex> lock(mMutex);
    122 
    123     if (keySize < 0) {
    124         ALOGW("nnCache::getBlob: negative sizes are not allowed");
    125         return 0;
    126     }
    127 
    128     if (mInitialized) {
    129         BlobCache* bc = getBlobCacheLocked();
    130         return bc->get(key, keySize, value, alloc);
    131     }
    132     return 0;
    133 }
    134 
    135 void NNCache::setCacheFilename(const char* filename) {
    136     std::lock_guard<std::mutex> lock(mMutex);
    137     mFilename = filename;
    138 }
    139 
    140 BlobCache* NNCache::getBlobCacheLocked() {
    141     if (mBlobCache == nullptr) {
    142         mBlobCache.reset(new BlobCache(mMaxKeySize, mMaxValueSize, mMaxTotalSize, mPolicy));
    143         loadBlobCacheLocked();
    144     }
    145     return mBlobCache.get();
    146 }
    147 
    148 static uint32_t crc32c(const uint8_t* buf, size_t len) {
    149     const uint32_t polyBits = 0x82F63B78;
    150     uint32_t r = 0;
    151     for (size_t i = 0; i < len; i++) {
    152         r ^= buf[i];
    153         for (int j = 0; j < 8; j++) {
    154             if (r & 1) {
    155                 r = (r >> 1) ^ polyBits;
    156             } else {
    157                 r >>= 1;
    158             }
    159         }
    160     }
    161     return r;
    162 }
    163 
    164 void NNCache::saveBlobCacheLocked() {
    165     if (mFilename.length() > 0 && mBlobCache != NULL) {
    166         size_t cacheSize = mBlobCache->getFlattenedSize();
    167         size_t headerSize = cacheFileHeaderSize;
    168         const char* fname = mFilename.c_str();
    169 
    170         // Try to create the file with no permissions so we can write it
    171         // without anyone trying to read it.
    172         int fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
    173         if (fd == -1) {
    174             if (errno == EEXIST) {
    175                 // The file exists, delete it and try again.
    176                 if (unlink(fname) == -1) {
    177                     // No point in retrying if the unlink failed.
    178                     ALOGE("error unlinking cache file %s: %s (%d)", fname,
    179                             strerror(errno), errno);
    180                     return;
    181                 }
    182                 // Retry now that we've unlinked the file.
    183                 fd = open(fname, O_CREAT | O_EXCL | O_RDWR, 0);
    184             }
    185             if (fd == -1) {
    186                 ALOGE("error creating cache file %s: %s (%d)", fname,
    187                         strerror(errno), errno);
    188                 return;
    189             }
    190         }
    191 
    192         size_t fileSize = headerSize + cacheSize;
    193 
    194         uint8_t* buf = new uint8_t [fileSize];
    195         if (!buf) {
    196             ALOGE("error allocating buffer for cache contents: %s (%d)",
    197                     strerror(errno), errno);
    198             close(fd);
    199             unlink(fname);
    200             return;
    201         }
    202 
    203         int err = mBlobCache->flatten(buf + headerSize, cacheSize);
    204         if (err < 0) {
    205             ALOGE("error writing cache contents: %s (%d)", strerror(-err),
    206                     -err);
    207             delete [] buf;
    208             close(fd);
    209             unlink(fname);
    210             return;
    211         }
    212 
    213         // Write the file magic and CRC
    214         memcpy(buf, cacheFileMagic, 4);
    215         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
    216         *crc = crc32c(buf + headerSize, cacheSize);
    217 
    218         if (write(fd, buf, fileSize) == -1) {
    219             ALOGE("error writing cache file: %s (%d)", strerror(errno),
    220                     errno);
    221             delete [] buf;
    222             close(fd);
    223             unlink(fname);
    224             return;
    225         }
    226 
    227         delete [] buf;
    228         fchmod(fd, S_IRUSR);
    229         close(fd);
    230     }
    231 }
    232 
    233 void NNCache::loadBlobCacheLocked() {
    234     if (mFilename.length() > 0) {
    235         size_t headerSize = cacheFileHeaderSize;
    236 
    237         int fd = open(mFilename.c_str(), O_RDONLY, 0);
    238         if (fd == -1) {
    239             if (errno != ENOENT) {
    240                 ALOGE("error opening cache file %s: %s (%d)", mFilename.c_str(),
    241                         strerror(errno), errno);
    242             }
    243             return;
    244         }
    245 
    246         struct stat statBuf;
    247         if (fstat(fd, &statBuf) == -1) {
    248             ALOGE("error stat'ing cache file: %s (%d)", strerror(errno), errno);
    249             close(fd);
    250             return;
    251         }
    252 
    253         // Sanity check the size before trying to mmap it.
    254         size_t fileSize = statBuf.st_size;
    255         if (fileSize > mMaxTotalSize * 2) {
    256             ALOGE("cache file is too large: %#" PRIx64,
    257                   static_cast<off64_t>(statBuf.st_size));
    258             close(fd);
    259             return;
    260         }
    261 
    262         uint8_t* buf = reinterpret_cast<uint8_t*>(mmap(NULL, fileSize,
    263                 PROT_READ, MAP_PRIVATE, fd, 0));
    264         if (buf == MAP_FAILED) {
    265             ALOGE("error mmaping cache file: %s (%d)", strerror(errno),
    266                     errno);
    267             close(fd);
    268             return;
    269         }
    270 
    271         // Check the file magic and CRC
    272         size_t cacheSize = fileSize - headerSize;
    273         if (memcmp(buf, cacheFileMagic, 4) != 0) {
    274             ALOGE("cache file has bad mojo");
    275             close(fd);
    276             return;
    277         }
    278         uint32_t* crc = reinterpret_cast<uint32_t*>(buf + 4);
    279         if (crc32c(buf + headerSize, cacheSize) != *crc) {
    280             ALOGE("cache file failed CRC check");
    281             close(fd);
    282             return;
    283         }
    284 
    285         int err = mBlobCache->unflatten(buf + headerSize, cacheSize);
    286         if (err < 0) {
    287             ALOGE("error reading cache contents: %s (%d)", strerror(-err),
    288                     -err);
    289             munmap(buf, fileSize);
    290             close(fd);
    291             return;
    292         }
    293 
    294         munmap(buf, fileSize);
    295         close(fd);
    296     }
    297 }
    298 
    299 // ----------------------------------------------------------------------------
    300 }; // namespace android
    301 // ----------------------------------------------------------------------------
    302