Home | History | Annotate | Download | only in net
      1 //
      2 // Copyright (C) 2011 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/net/byte_string.h"
     18 
     19 #include <netinet/in.h>
     20 #include <string.h>
     21 
     22 #include <algorithm>
     23 
     24 #include <base/strings/string_number_conversions.h>
     25 
     26 using std::min;
     27 using std::string;
     28 using std::vector;
     29 
     30 namespace shill {
     31 
     32 ByteString::ByteString(const ByteString& b) {
     33   data_ = b.data_;
     34 }
     35 
     36 ByteString& ByteString::operator=(const ByteString& b) {
     37   data_ = b.data_;
     38   return *this;
     39 }
     40 
     41 unsigned char* ByteString::GetData() {
     42   return (GetLength() == 0) ? nullptr : &data_.front();
     43 }
     44 
     45 const unsigned char* ByteString::GetConstData() const {
     46   return (GetLength() == 0) ? nullptr : &data_.front();
     47 }
     48 
     49 size_t ByteString::GetLength() const {
     50   return data_.size();
     51 }
     52 
     53 ByteString ByteString::GetSubstring(size_t offset, size_t length) const {
     54   if (offset > GetLength()) {
     55     offset = GetLength();
     56   }
     57   if (length > GetLength() - offset) {
     58     length = GetLength() - offset;
     59   }
     60   return ByteString(GetConstData() + offset, length);
     61 }
     62 
     63 // static
     64 ByteString ByteString::CreateFromCPUUInt32(uint32_t val) {
     65   return ByteString(reinterpret_cast<unsigned char*>(&val), sizeof(val));
     66 }
     67 
     68 // static
     69 ByteString ByteString::CreateFromNetUInt32(uint32_t val) {
     70   return CreateFromCPUUInt32(ntohl(val));
     71 }
     72 
     73 // static
     74 ByteString ByteString::CreateFromHexString(const string& hex_string) {
     75   vector<uint8_t> bytes;
     76   if (!base::HexStringToBytes(hex_string, &bytes)) {
     77     return ByteString();
     78   }
     79   return ByteString(&bytes.front(), bytes.size());
     80 }
     81 
     82 bool ByteString::ConvertToCPUUInt32(uint32_t* val) const {
     83   if (val == nullptr || GetLength() != sizeof(*val)) {
     84     return false;
     85   }
     86   memcpy(val, GetConstData(), sizeof(*val));
     87 
     88   return true;
     89 }
     90 
     91 bool ByteString::ConvertToNetUInt32(uint32_t* val) const {
     92   if (!ConvertToCPUUInt32(val)) {
     93     return false;
     94   }
     95   *val = ntohl(*val);
     96   return true;
     97 }
     98 
     99 template <typename T>
    100 bool ByteString::ConvertByteOrderAsUIntArray(T (*converter)(T)) {
    101   size_t length = GetLength();
    102   if ((length % sizeof(T)) != 0) {
    103     return false;
    104   }
    105   for (auto i = data_.begin(); i != data_.end(); i += sizeof(T)) {
    106     // Take care of word alignment.
    107     T val;
    108     memcpy(&val, &(*i), sizeof(T));
    109     val = converter(val);
    110     memcpy(&(*i), &val, sizeof(T));
    111   }
    112   return true;
    113 }
    114 
    115 bool ByteString::ConvertFromNetToCPUUInt32Array() {
    116   return ConvertByteOrderAsUIntArray(ntohl);
    117 }
    118 
    119 bool ByteString::ConvertFromCPUToNetUInt32Array() {
    120   return ConvertByteOrderAsUIntArray(htonl);
    121 }
    122 
    123 bool ByteString::IsZero() const {
    124   for (const auto& i : data_) {
    125     if (i != 0) {
    126       return false;
    127     }
    128   }
    129   return true;
    130 }
    131 
    132 bool ByteString::BitwiseAnd(const ByteString& b) {
    133   if (GetLength() != b.GetLength()) {
    134     return false;
    135   }
    136   auto lhs = data_.begin();
    137   for (const auto& rhs : b.data_) {
    138     *lhs++ &= rhs;
    139   }
    140   return true;
    141 }
    142 
    143 bool ByteString::BitwiseOr(const ByteString& b) {
    144   if (GetLength() != b.GetLength()) {
    145     return false;
    146   }
    147   auto lhs = data_.begin();
    148   for (const auto& rhs : b.data_) {
    149     *lhs++ |= rhs;
    150   }
    151   return true;
    152 }
    153 
    154 void ByteString::BitwiseInvert() {
    155   for (auto& i : data_) {
    156     i = ~i;
    157   }
    158 }
    159 
    160 bool ByteString::Equals(const ByteString& b) const {
    161   if (GetLength() != b.GetLength()) {
    162     return false;
    163   }
    164   auto lhs = data_.begin();
    165   for (const auto& rhs : b.data_) {
    166     if (*lhs++ != rhs) {
    167       return false;
    168     }
    169   }
    170   return true;
    171 }
    172 
    173 void ByteString::Append(const ByteString& b) {
    174   data_.insert(data_.end(), b.data_.begin(), b.data_.end());
    175 }
    176 
    177 void ByteString::Clear() {
    178   data_.clear();
    179 }
    180 
    181 void ByteString::Resize(int size) {
    182   data_.resize(size, 0);
    183 }
    184 
    185 string ByteString::HexEncode() const {
    186   return base::HexEncode(GetConstData(), GetLength());
    187 }
    188 
    189 bool ByteString::CopyData(size_t size, void* output) const {
    190   if (output == nullptr || GetLength() < size) {
    191     return false;
    192   }
    193   memcpy(output, GetConstData(), size);
    194   return true;
    195 }
    196 
    197 // static
    198 bool ByteString::IsLessThan(const ByteString& lhs, const ByteString& rhs) {
    199   size_t byte_count = min(lhs.GetLength(), rhs.GetLength());
    200   int result = memcmp(lhs.GetConstData(), rhs.GetConstData(), byte_count);
    201   if (result == 0) {
    202     return lhs.GetLength() < rhs.GetLength();
    203   }
    204   return result < 0;
    205 }
    206 
    207 }  // namespace shill
    208