Home | History | Annotate | Download | only in functional
      1 /*
      2  * Copyright (C) 2014 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "authorization_set.h"
     18 
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <istream>
#include <limits>
#include <new>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
     28 
     29 namespace android {
     30 namespace hardware {
     31 namespace keymaster {
     32 namespace V3_0 {
     33 
     34 inline bool keyParamLess(const KeyParameter& a, const KeyParameter& b) {
     35     if (a.tag != b.tag) return a.tag < b.tag;
     36     int retval;
     37     switch (typeFromTag(a.tag)) {
     38     case TagType::INVALID:
     39     case TagType::BOOL:
     40         return false;
     41     case TagType::ENUM:
     42     case TagType::ENUM_REP:
     43     case TagType::UINT:
     44     case TagType::UINT_REP:
     45         return a.f.integer < b.f.integer;
     46     case TagType::ULONG:
     47     case TagType::ULONG_REP:
     48         return a.f.longInteger < b.f.longInteger;
     49     case TagType::DATE:
     50         return a.f.dateTime < b.f.dateTime;
     51     case TagType::BIGNUM:
     52     case TagType::BYTES:
     53         // Handle the empty cases.
     54         if (a.blob.size() == 0) return b.blob.size() != 0;
     55         if (b.blob.size() == 0) return false;
     56 
     57         retval = memcmp(&a.blob[0], &b.blob[0], std::min(a.blob.size(), b.blob.size()));
     58         if (retval == 0) {
     59             // One is the prefix of the other, so the longer wins
     60             return a.blob.size() < b.blob.size();
     61         } else {
     62             return retval < 0;
     63         }
     64     }
     65     return false;
     66 }
     67 
     68 inline bool keyParamEqual(const KeyParameter& a, const KeyParameter& b) {
     69     if (a.tag != b.tag) return false;
     70 
     71     switch (typeFromTag(a.tag)) {
     72     case TagType::INVALID:
     73     case TagType::BOOL:
     74         return true;
     75     case TagType::ENUM:
     76     case TagType::ENUM_REP:
     77     case TagType::UINT:
     78     case TagType::UINT_REP:
     79         return a.f.integer == b.f.integer;
     80     case TagType::ULONG:
     81     case TagType::ULONG_REP:
     82         return a.f.longInteger == b.f.longInteger;
     83     case TagType::DATE:
     84         return a.f.dateTime == b.f.dateTime;
     85     case TagType::BIGNUM:
     86     case TagType::BYTES:
     87         if (a.blob.size() != b.blob.size()) return false;
     88         return a.blob.size() == 0 || memcmp(&a.blob[0], &b.blob[0], a.blob.size()) == 0;
     89     }
     90     return false;
     91 }
     92 
     93 void AuthorizationSet::Sort() {
     94     std::sort(data_.begin(), data_.end(), keyParamLess);
     95 }
     96 
     97 void AuthorizationSet::Deduplicate() {
     98     if (data_.empty()) return;
     99 
    100     Sort();
    101     std::vector<KeyParameter> result;
    102 
    103     auto curr = data_.begin();
    104     auto prev = curr++;
    105     for (; curr != data_.end(); ++prev, ++curr) {
    106         if (prev->tag == Tag::INVALID) continue;
    107 
    108         if (!keyParamEqual(*prev, *curr)) {
    109             result.emplace_back(std::move(*prev));
    110         }
    111     }
    112     result.emplace_back(std::move(*prev));
    113 
    114     std::swap(data_, result);
    115 }
    116 
    117 void AuthorizationSet::Union(const AuthorizationSet& other) {
    118     data_.insert(data_.end(), other.data_.begin(), other.data_.end());
    119     Deduplicate();
    120 }
    121 
    122 void AuthorizationSet::Subtract(const AuthorizationSet& other) {
    123     Deduplicate();
    124 
    125     auto i = other.begin();
    126     while (i != other.end()) {
    127         int pos = -1;
    128         do {
    129             pos = find(i->tag, pos);
    130             if (pos != -1 && keyParamEqual(*i, data_[pos])) {
    131                 data_.erase(data_.begin() + pos);
    132                 break;
    133             }
    134         } while (pos != -1);
    135         ++i;
    136     }
    137 }
    138 
    139 int AuthorizationSet::find(Tag tag, int begin) const {
    140     auto iter = data_.begin() + (1 + begin);
    141 
    142     while (iter != data_.end() && iter->tag != tag)
    143         ++iter;
    144 
    145     if (iter != data_.end()) return iter - data_.begin();
    146     return -1;
    147 }
    148 
    149 bool AuthorizationSet::erase(int index) {
    150     auto pos = data_.begin() + index;
    151     if (pos != data_.end()) {
    152         data_.erase(pos);
    153         return true;
    154     }
    155     return false;
    156 }
    157 
    158 KeyParameter& AuthorizationSet::operator[](int at) {
    159     return data_[at];
    160 }
    161 
    162 const KeyParameter& AuthorizationSet::operator[](int at) const {
    163     return data_[at];
    164 }
    165 
    166 void AuthorizationSet::Clear() {
    167     data_.clear();
    168 }
    169 
    170 size_t AuthorizationSet::GetTagCount(Tag tag) const {
    171     size_t count = 0;
    172     for (int pos = -1; (pos = find(tag, pos)) != -1;)
    173         ++count;
    174     return count;
    175 }
    176 
    177 NullOr<const KeyParameter&> AuthorizationSet::GetEntry(Tag tag) const {
    178     int pos = find(tag);
    179     if (pos == -1) return {};
    180     return data_[pos];
    181 }
    182 
    183 /**
    184  * Persistent format is:
    185  * | 32 bit indirect_size         |
    186  * --------------------------------
    187  * | indirect_size bytes of data  | this is where the blob data is stored
    188  * --------------------------------
    189  * | 32 bit element_count         | number of entries
    190  * | 32 bit elements_size         | total bytes used by entries (entries have variable length)
    191  * --------------------------------
 * | elements_size bytes of data  | where the element records are stored
    193  */
    194 
    195 /**
    196  * Persistent format of blobs and bignums:
    197  * | 32 bit tag             |
    198  * | 32 bit blob_length     |
    199  * | 32 bit indirect_offset |
    200  */
    201 
    202 struct OutStreams {
    203     std::ostream& indirect;
    204     std::ostream& elements;
    205 };
    206 
    207 OutStreams& serializeParamValue(OutStreams& out, const hidl_vec<uint8_t>& blob) {
    208     uint32_t buffer;
    209 
    210     // write blob_length
    211     auto blob_length = blob.size();
    212     if (blob_length > std::numeric_limits<uint32_t>::max()) {
    213         out.elements.setstate(std::ios_base::badbit);
    214         return out;
    215     }
    216     buffer = blob_length;
    217     out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
    218 
    219     // write indirect_offset
    220     auto offset = out.indirect.tellp();
    221     if (offset < 0 || offset > std::numeric_limits<uint32_t>::max() ||
    222         static_cast<uint32_t>((std::numeric_limits<uint32_t>::max() - offset)) <
    223             blob_length) {  // overflow check
    224         out.elements.setstate(std::ios_base::badbit);
    225         return out;
    226     }
    227     buffer = offset;
    228     out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
    229 
    230     // write blob to indirect stream
    231     if (blob_length) out.indirect.write(reinterpret_cast<const char*>(&blob[0]), blob_length);
    232 
    233     return out;
    234 }
    235 
    236 template <typename T> OutStreams& serializeParamValue(OutStreams& out, const T& value) {
    237     out.elements.write(reinterpret_cast<const char*>(&value), sizeof(T));
    238     return out;
    239 }
    240 
    241 OutStreams& serialize(TAG_INVALID_t&&, OutStreams& out, const KeyParameter&) {
    242     // skip invalid entries.
    243     return out;
    244 }
    245 template <typename T> OutStreams& serialize(T ttag, OutStreams& out, const KeyParameter& param) {
    246     out.elements.write(reinterpret_cast<const char*>(&param.tag), sizeof(int32_t));
    247     return serializeParamValue(out, accessTagValue(ttag, param));
    248 }
    249 
    250 template <typename... T> struct choose_serializer;
    251 template <typename... Tags> struct choose_serializer<MetaList<Tags...>> {
    252     static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
    253         return choose_serializer<Tags...>::serialize(out, param);
    254     }
    255 };
    256 template <> struct choose_serializer<> {
    257     static OutStreams& serialize(OutStreams& out, const KeyParameter&) { return out; }
    258 };
    259 template <TagType tag_type, Tag tag, typename... Tail>
    260 struct choose_serializer<TypedTag<tag_type, tag>, Tail...> {
    261     static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
    262         if (param.tag == tag) {
    263             return V3_0::serialize(TypedTag<tag_type, tag>(), out, param);
    264         } else {
    265             return choose_serializer<Tail...>::serialize(out, param);
    266         }
    267     }
    268 };
    269 
    270 OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
    271     return choose_serializer<all_tags_t>::serialize(out, param);
    272 }
    273 
    274 std::ostream& serialize(std::ostream& out, const std::vector<KeyParameter>& params) {
    275     std::stringstream indirect;
    276     std::stringstream elements;
    277     OutStreams streams = {indirect, elements};
    278     for (const auto& param : params) {
    279         serialize(streams, param);
    280     }
    281     if (indirect.bad() || elements.bad()) {
    282         out.setstate(std::ios_base::badbit);
    283         return out;
    284     }
    285     auto pos = indirect.tellp();
    286     if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
    287         out.setstate(std::ios_base::badbit);
    288         return out;
    289     }
    290     uint32_t indirect_size = pos;
    291     pos = elements.tellp();
    292     if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
    293         out.setstate(std::ios_base::badbit);
    294         return out;
    295     }
    296     uint32_t elements_size = pos;
    297     uint32_t element_count = params.size();
    298 
    299     out.write(reinterpret_cast<const char*>(&indirect_size), sizeof(uint32_t));
    300 
    301     pos = out.tellp();
    302     if (indirect_size) out << indirect.rdbuf();
    303     assert(out.tellp() - pos == indirect_size);
    304 
    305     out.write(reinterpret_cast<const char*>(&element_count), sizeof(uint32_t));
    306     out.write(reinterpret_cast<const char*>(&elements_size), sizeof(uint32_t));
    307 
    308     pos = out.tellp();
    309     if (elements_size) out << elements.rdbuf();
    310     assert(out.tellp() - pos == elements_size);
    311 
    312     return out;
    313 }
    314 
    315 struct InStreams {
    316     std::istream& indirect;
    317     std::istream& elements;
    318 };
    319 
    320 InStreams& deserializeParamValue(InStreams& in, hidl_vec<uint8_t>* blob) {
    321     uint32_t blob_length = 0;
    322     uint32_t offset = 0;
    323     in.elements.read(reinterpret_cast<char*>(&blob_length), sizeof(uint32_t));
    324     blob->resize(blob_length);
    325     in.elements.read(reinterpret_cast<char*>(&offset), sizeof(uint32_t));
    326     in.indirect.seekg(offset);
    327     in.indirect.read(reinterpret_cast<char*>(&(*blob)[0]), blob->size());
    328     return in;
    329 }
    330 
    331 template <typename T> InStreams& deserializeParamValue(InStreams& in, T* value) {
    332     in.elements.read(reinterpret_cast<char*>(value), sizeof(T));
    333     return in;
    334 }
    335 
    336 InStreams& deserialize(TAG_INVALID_t&&, InStreams& in, KeyParameter*) {
    337     // there should be no invalid KeyParamaters but if handle them as zero sized.
    338     return in;
    339 }
    340 
    341 template <typename T> InStreams& deserialize(T&& ttag, InStreams& in, KeyParameter* param) {
    342     return deserializeParamValue(in, &accessTagValue(ttag, *param));
    343 }
    344 
    345 template <typename... T> struct choose_deserializer;
    346 template <typename... Tags> struct choose_deserializer<MetaList<Tags...>> {
    347     static InStreams& deserialize(InStreams& in, KeyParameter* param) {
    348         return choose_deserializer<Tags...>::deserialize(in, param);
    349     }
    350 };
    351 template <> struct choose_deserializer<> {
    352     static InStreams& deserialize(InStreams& in, KeyParameter*) {
    353         // encountered an unknown tag -> fail parsing
    354         in.elements.setstate(std::ios_base::badbit);
    355         return in;
    356     }
    357 };
    358 template <TagType tag_type, Tag tag, typename... Tail>
    359 struct choose_deserializer<TypedTag<tag_type, tag>, Tail...> {
    360     static InStreams& deserialize(InStreams& in, KeyParameter* param) {
    361         if (param->tag == tag) {
    362             return V3_0::deserialize(TypedTag<tag_type, tag>(), in, param);
    363         } else {
    364             return choose_deserializer<Tail...>::deserialize(in, param);
    365         }
    366     }
    367 };
    368 
    369 InStreams& deserialize(InStreams& in, KeyParameter* param) {
    370     in.elements.read(reinterpret_cast<char*>(&param->tag), sizeof(Tag));
    371     return choose_deserializer<all_tags_t>::deserialize(in, param);
    372 }
    373 
    374 std::istream& deserialize(std::istream& in, std::vector<KeyParameter>* params) {
    375     uint32_t indirect_size = 0;
    376     in.read(reinterpret_cast<char*>(&indirect_size), sizeof(uint32_t));
    377     std::string indirect_buffer(indirect_size, '\0');
    378     if (indirect_buffer.size() != indirect_size) {
    379         in.setstate(std::ios_base::badbit);
    380         return in;
    381     }
    382     in.read(&indirect_buffer[0], indirect_buffer.size());
    383 
    384     uint32_t element_count = 0;
    385     in.read(reinterpret_cast<char*>(&element_count), sizeof(uint32_t));
    386     uint32_t elements_size = 0;
    387     in.read(reinterpret_cast<char*>(&elements_size), sizeof(uint32_t));
    388 
    389     std::string elements_buffer(elements_size, '\0');
    390     if (elements_buffer.size() != elements_size) {
    391         in.setstate(std::ios_base::badbit);
    392         return in;
    393     }
    394     in.read(&elements_buffer[0], elements_buffer.size());
    395 
    396     if (in.bad()) return in;
    397 
    398     // TODO write one-shot stream buffer to avoid copying here
    399     std::stringstream indirect(indirect_buffer);
    400     std::stringstream elements(elements_buffer);
    401     InStreams streams = {indirect, elements};
    402 
    403     params->resize(element_count);
    404 
    405     for (uint32_t i = 0; i < element_count; ++i) {
    406         deserialize(streams, &(*params)[i]);
    407     }
    408     return in;
    409 }
    410 
    411 void AuthorizationSet::Serialize(std::ostream* out) const {
    412     serialize(*out, data_);
    413 }
    414 
    415 void AuthorizationSet::Deserialize(std::istream* in) {
    416     deserialize(*in, &data_);
    417 }
    418 
    419 }  // namespace V3_0
    420 }  // namespace keymaster
    421 }  // namespace hardware
    422 }  // namespace android
    423