Home | History | Annotate | Download | only in keystore
      1 /*
      2  * Copyright (C) 2014 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include <keystore/authorization_set.h>

#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <istream>
#include <limits>
#include <new>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
     28 
     29 namespace keystore {
     30 
     31 inline bool keyParamLess(const KeyParameter& a, const KeyParameter& b) {
     32     if (a.tag != b.tag) return a.tag < b.tag;
     33     int retval;
     34     switch (typeFromTag(a.tag)) {
     35     case TagType::INVALID:
     36     case TagType::BOOL:
     37         return false;
     38     case TagType::ENUM:
     39     case TagType::ENUM_REP:
     40     case TagType::UINT:
     41     case TagType::UINT_REP:
     42         return a.f.integer < b.f.integer;
     43     case TagType::ULONG:
     44     case TagType::ULONG_REP:
     45         return a.f.longInteger < b.f.longInteger;
     46     case TagType::DATE:
     47         return a.f.dateTime < b.f.dateTime;
     48     case TagType::BIGNUM:
     49     case TagType::BYTES:
     50         // Handle the empty cases.
     51         if (a.blob.size() == 0)
     52             return b.blob.size() != 0;
     53         if (b.blob.size() == 0) return false;
     54 
     55         retval = memcmp(&a.blob[0], &b.blob[0], std::min(a.blob.size(), b.blob.size()));
     56         // if one is the prefix of the other the longer wins
     57         if (retval == 0) return a.blob.size() < b.blob.size();
     58         // Otherwise a is less if a is less.
     59         else return retval < 0;
     60     }
     61     return false;
     62 }
     63 
     64 inline bool keyParamEqual(const KeyParameter& a, const KeyParameter& b) {
     65     if (a.tag != b.tag) return false;
     66 
     67     switch (typeFromTag(a.tag)) {
     68     case TagType::INVALID:
     69     case TagType::BOOL:
     70         return true;
     71     case TagType::ENUM:
     72     case TagType::ENUM_REP:
     73     case TagType::UINT:
     74     case TagType::UINT_REP:
     75         return a.f.integer == b.f.integer;
     76     case TagType::ULONG:
     77     case TagType::ULONG_REP:
     78         return a.f.longInteger == b.f.longInteger;
     79     case TagType::DATE:
     80         return a.f.dateTime == b.f.dateTime;
     81     case TagType::BIGNUM:
     82     case TagType::BYTES:
     83         if (a.blob.size() != b.blob.size()) return false;
     84         return a.blob.size() == 0 ||
     85                 memcmp(&a.blob[0], &b.blob[0], a.blob.size()) == 0;
     86     }
     87     return false;
     88 }
     89 
     90 void AuthorizationSet::Sort() {
     91     std::sort(data_.begin(), data_.end(), keyParamLess);
     92 }
     93 
     94 void AuthorizationSet::Deduplicate() {
     95     if (data_.empty()) return;
     96 
     97     Sort();
     98     std::vector<KeyParameter> result;
     99 
    100     auto curr = data_.begin();
    101     auto prev = curr++;
    102     for (; curr != data_.end(); ++prev, ++curr) {
    103         if (prev->tag == Tag::INVALID) continue;
    104 
    105         if (!keyParamEqual(*prev, *curr)) {
    106             result.emplace_back(std::move(*prev));
    107         }
    108     }
    109     result.emplace_back(std::move(*prev));
    110 
    111     std::swap(data_, result);
    112 }
    113 
    114 void AuthorizationSet::Union(const AuthorizationSet& other) {
    115     data_.insert(data_.end(), other.data_.begin(), other.data_.end());
    116     Deduplicate();
    117 }
    118 
    119 void AuthorizationSet::Subtract(const AuthorizationSet& other) {
    120     Deduplicate();
    121 
    122     auto i = other.begin();
    123     while (i != other.end()) {
    124         int pos = -1;
    125         do {
    126             pos = find(i->tag, pos);
    127             if (pos != -1 && keyParamEqual(*i, data_[pos])) {
    128                 data_.erase(data_.begin() + pos);
    129                 break;
    130             }
    131         } while (pos != -1);
    132         ++i;
    133     }
    134 }
    135 
    136 int AuthorizationSet::find(Tag tag, int begin) const {
    137     auto iter = data_.begin() + (1 + begin);
    138 
    139     while (iter != data_.end() && iter->tag != tag) ++iter;
    140 
    141     if (iter != data_.end()) return iter - data_.begin();
    142     return -1;
    143 }
    144 
    145 bool AuthorizationSet::erase(int index) {
    146     auto pos = data_.begin() + index;
    147     if (pos != data_.end()) {
    148         data_.erase(pos);
    149         return true;
    150     }
    151     return false;
    152 }
    153 
    154 KeyParameter& AuthorizationSet::operator[](int at) {
    155     return data_[at];
    156 }
    157 
    158 const KeyParameter& AuthorizationSet::operator[](int at) const {
    159     return data_[at];
    160 }
    161 
    162 void AuthorizationSet::Clear() {
    163     data_.clear();
    164 }
    165 
    166 size_t AuthorizationSet::GetTagCount(Tag tag) const {
    167     size_t count = 0;
    168     for (int pos = -1; (pos = find(tag, pos)) != -1;)
    169         ++count;
    170     return count;
    171 }
    172 
    173 NullOr<const KeyParameter&> AuthorizationSet::GetEntry(Tag tag) const {
    174     int pos = find(tag);
    175     if (pos == -1) return {};
    176     return data_[pos];
    177 }
    178 
    179 /**
    180  * Persistent format is:
    181  * | 32 bit indirect_size         |
    182  * --------------------------------
    183  * | indirect_size bytes of data  | this is where the blob data is stored
    184  * --------------------------------
    185  * | 32 bit element_count         | number of entries
    186  * | 32 bit elements_size         | total bytes used by entries (entries have variable length)
    187  * --------------------------------
 * | elements_size bytes of data  | where the elements are stored
    189  */
    190 
    191 /**
    192  * Persistent format of blobs and bignums:
    193  * | 32 bit tag             |
    194  * | 32 bit blob_length     |
    195  * | 32 bit indirect_offset |
    196  */
    197 
    198 struct OutStreams {
    199     std::ostream& indirect;
    200     std::ostream& elements;
    201 };
    202 
    203 OutStreams& serializeParamValue(OutStreams& out, const hidl_vec<uint8_t>& blob) {
    204     uint32_t buffer;
    205 
    206     // write blob_length
    207     auto blob_length = blob.size();
    208     if (blob_length > std::numeric_limits<uint32_t>::max()) {
    209         out.elements.setstate(std::ios_base::badbit);
    210         return out;
    211     }
    212     buffer = blob_length;
    213     out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
    214 
    215     // write indirect_offset
    216     auto offset = out.indirect.tellp();
    217     if (offset < 0 || offset > std::numeric_limits<uint32_t>::max() ||
    218             uint32_t(offset) + uint32_t(blob_length) < uint32_t(offset)) { // overflow check
    219         out.elements.setstate(std::ios_base::badbit);
    220         return out;
    221     }
    222     buffer = offset;
    223     out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
    224 
    225     // write blob to indirect stream
    226     if(blob_length)
    227         out.indirect.write(reinterpret_cast<const char*>(&blob[0]), blob_length);
    228 
    229     return out;
    230 }
    231 
    232 template <typename T>
    233 OutStreams& serializeParamValue(OutStreams& out, const T& value) {
    234     out.elements.write(reinterpret_cast<const char*>(&value), sizeof(T));
    235     return out;
    236 }
    237 
    238 OutStreams& serialize(TAG_INVALID_t&&, OutStreams& out, const KeyParameter&) {
    239     // skip invalid entries.
    240     return out;
    241 }
    242 template <typename T>
    243 OutStreams& serialize(T ttag, OutStreams& out, const KeyParameter& param) {
    244     out.elements.write(reinterpret_cast<const char*>(&param.tag), sizeof(int32_t));
    245     return serializeParamValue(out, accessTagValue(ttag, param));
    246 }
    247 
    248 template <typename... T>
    249 struct choose_serializer;
    250 template <typename... Tags>
    251 struct choose_serializer<MetaList<Tags...>> {
    252     static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
    253         return choose_serializer<Tags...>::serialize(out, param);
    254     }
    255 };
    256 template <>
    257 struct choose_serializer<> {
    258     static OutStreams& serialize(OutStreams& out, const KeyParameter&) {
    259         return out;
    260     }
    261 };
    262 template <TagType tag_type, Tag tag, typename... Tail>
    263 struct choose_serializer<TypedTag<tag_type, tag>, Tail...> {
    264     static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
    265         if (param.tag == tag) {
    266             return keystore::serialize(TypedTag<tag_type, tag>(), out, param);
    267         } else {
    268             return choose_serializer<Tail...>::serialize(out, param);
    269         }
    270     }
    271 };
    272 
    273 OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
    274     return choose_serializer<all_tags_t>::serialize(out, param);
    275 }
    276 
    277 std::ostream& serialize(std::ostream& out, const std::vector<KeyParameter>& params) {
    278     std::stringstream indirect;
    279     std::stringstream elements;
    280     OutStreams streams = { indirect, elements };
    281     for (const auto& param: params) {
    282         serialize(streams, param);
    283     }
    284     if (indirect.bad() || elements.bad()) {
    285         out.setstate(std::ios_base::badbit);
    286         return out;
    287     }
    288     auto pos = indirect.tellp();
    289     if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
    290         out.setstate(std::ios_base::badbit);
    291         return out;
    292     }
    293     uint32_t indirect_size = pos;
    294     pos = elements.tellp();
    295     if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
    296         out.setstate(std::ios_base::badbit);
    297         return out;
    298     }
    299     uint32_t elements_size = pos;
    300     uint32_t element_count = params.size();
    301 
    302     out.write(reinterpret_cast<const char*>(&indirect_size), sizeof(uint32_t));
    303 
    304     pos = out.tellp();
    305     if (indirect_size)
    306         out << indirect.rdbuf();
    307     assert(out.tellp() - pos == indirect_size);
    308 
    309     out.write(reinterpret_cast<const char*>(&element_count), sizeof(uint32_t));
    310     out.write(reinterpret_cast<const char*>(&elements_size), sizeof(uint32_t));
    311 
    312     pos = out.tellp();
    313     if (elements_size)
    314         out << elements.rdbuf();
    315     assert(out.tellp() - pos == elements_size);
    316 
    317     return out;
    318 }
    319 
    320 struct InStreams {
    321     std::istream& indirect;
    322     std::istream& elements;
    323 };
    324 
    325 InStreams& deserializeParamValue(InStreams& in, hidl_vec<uint8_t>* blob) {
    326     uint32_t blob_length = 0;
    327     uint32_t offset = 0;
    328     in.elements.read(reinterpret_cast<char*>(&blob_length), sizeof(uint32_t));
    329     blob->resize(blob_length);
    330     in.elements.read(reinterpret_cast<char*>(&offset), sizeof(uint32_t));
    331     in.indirect.seekg(offset);
    332     in.indirect.read(reinterpret_cast<char*>(&(*blob)[0]), blob->size());
    333     return in;
    334 }
    335 
    336 template <typename T>
    337 InStreams& deserializeParamValue(InStreams& in, T* value) {
    338     in.elements.read(reinterpret_cast<char*>(value), sizeof(T));
    339     return in;
    340 }
    341 
    342 InStreams& deserialize(TAG_INVALID_t&&, InStreams& in, KeyParameter*) {
    343     // there should be no invalid KeyParamaters but if handle them as zero sized.
    344     return in;
    345 }
    346 
    347 template <typename T>
    348 InStreams& deserialize(T&& ttag, InStreams& in, KeyParameter* param) {
    349     return deserializeParamValue(in, &accessTagValue(ttag, *param));
    350 }
    351 
    352 template <typename... T>
    353 struct choose_deserializer;
    354 template <typename... Tags>
    355 struct choose_deserializer<MetaList<Tags...>> {
    356     static InStreams& deserialize(InStreams& in, KeyParameter* param) {
    357         return choose_deserializer<Tags...>::deserialize(in, param);
    358     }
    359 };
    360 template <>
    361 struct choose_deserializer<> {
    362     static InStreams& deserialize(InStreams& in, KeyParameter*) {
    363         // encountered an unknown tag -> fail parsing
    364         in.elements.setstate(std::ios_base::badbit);
    365         return in;
    366     }
    367 };
    368 template <TagType tag_type, Tag tag, typename... Tail>
    369 struct choose_deserializer<TypedTag<tag_type, tag>, Tail...> {
    370     static InStreams& deserialize(InStreams& in, KeyParameter* param) {
    371         if (param->tag == tag) {
    372             return keystore::deserialize(TypedTag<tag_type, tag>(), in, param);
    373         } else {
    374             return choose_deserializer<Tail...>::deserialize(in, param);
    375         }
    376     }
    377 };
    378 
    379 InStreams& deserialize(InStreams& in, KeyParameter* param) {
    380     in.elements.read(reinterpret_cast<char*>(&param->tag), sizeof(Tag));
    381     return choose_deserializer<all_tags_t>::deserialize(in, param);
    382 }
    383 
    384 std::istream& deserialize(std::istream& in, std::vector<KeyParameter>* params) {
    385     uint32_t indirect_size = 0;
    386     in.read(reinterpret_cast<char*>(&indirect_size), sizeof(uint32_t));
    387     std::string indirect_buffer(indirect_size, '\0');
    388     if (indirect_buffer.size() != indirect_size) {
    389         in.setstate(std::ios_base::badbit);
    390         return in;
    391     }
    392     in.read(&indirect_buffer[0], indirect_buffer.size());
    393 
    394     uint32_t element_count = 0;
    395     in.read(reinterpret_cast<char*>(&element_count), sizeof(uint32_t));
    396     uint32_t elements_size = 0;
    397     in.read(reinterpret_cast<char*>(&elements_size), sizeof(uint32_t));
    398 
    399     std::string elements_buffer(elements_size, '\0');
    400     if(elements_buffer.size() != elements_size) {
    401         in.setstate(std::ios_base::badbit);
    402         return in;
    403     }
    404     in.read(&elements_buffer[0], elements_buffer.size());
    405 
    406     if (in.bad()) return in;
    407 
    408     // TODO write one-shot stream buffer to avoid copying here
    409     std::stringstream indirect(indirect_buffer);
    410     std::stringstream elements(elements_buffer);
    411     InStreams streams = { indirect, elements };
    412 
    413     params->resize(element_count);
    414 
    415     for (uint32_t i = 0; i < element_count; ++i) {
    416         deserialize(streams, &(*params)[i]);
    417     }
    418     return in;
    419 }
    420 void AuthorizationSet::Serialize(std::ostream* out) const {
    421     serialize(*out, data_);
    422 }
    423 void AuthorizationSet::Deserialize(std::istream* in) {
    424     deserialize(*in, &data_);
    425 }
    426 
    427 }  // namespace keystore
    428