Home | History | Annotate | Download | only in strings
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/lib/strings/base64.h"

#include <cstring>
#include <memory>
#include <new>
#include "tensorflow/core/lib/core/errors.h"
     21 
     22 namespace tensorflow {
     23 namespace {
     24 // This array must have signed type.
     25 // clang-format off
     26 constexpr int8 kBase64Bytes[128] = {
     27      -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
     28      -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
     29      -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
     30      -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,  0x3E,  -1,   -1,
     31     0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D,  -1,   -1,
     32      -1,   -1,   -1,   -1,   -1,  0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
     33     0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
     34     0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19,  -1,   -1,   -1,   -1,  0x3F,
     35      -1,  0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
     36     0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
     37     0x31, 0x32, 0x33,  -1,   -1,   -1,   -1,   -1};
     38 // clang-format on
     39 
     40 constexpr char kBase64UrlSafeChars[65] =
     41     "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
     42 
     43 constexpr char kPadChar = '=';
     44 
     45 // Converts a char (8 bits) into a 6-bit value for decoding. If the input char
     46 // is invalid for base64 encoding, the return value has at least its upper 25
     47 // bits set.
     48 inline uint32 Convert(char x) {
     49   // If x < 128, then we look up x in the table. If x is valid, then the table
     50   // will have a value <= 0x3F, otherwise the table will have -1. If x >= 128,
     51   // we still do some table lookup, but the value is ignored since we explicitly
     52   // set the high bit of y to 1. Either way, y is negative (high bit set) in
     53   // case of error.
     54   const int8 y = kBase64Bytes[x & 0x7F] | (x & 0x80);
     55   // Casting from int8 to int32 preserves sign by sign extension. If y was
     56   // negative, at least its 25 high bits of the return value are set.
     57   const int32 z = static_cast<int32>(y);
     58   return static_cast<uint32>(z);
     59 }
     60 
     61 Status DecodeThreeChars(const char* codes, char* result) {
     62   const uint32 packed = (Convert(codes[0]) << 18) | (Convert(codes[1]) << 12) |
     63                         (Convert(codes[2]) << 6) | (Convert(codes[3]));
     64   // Convert() return value has upper 25 bits set if input is invalid.
     65   // Therefore `packed` has high bits set iff at least one of code is invalid.
     66   if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) {
     67     return errors::InvalidArgument("Invalid character found in base64.");
     68   }
     69   result[0] = static_cast<char>(packed >> 16);
     70   result[1] = static_cast<char>(packed >> 8);
     71   result[2] = static_cast<char>(packed);
     72   return Status::OK();
     73 }
     74 }  // namespace
     75 
     76 Status Base64Decode(StringPiece data, string* decoded) {
     77   if (decoded == nullptr) {
     78     return errors::Internal("'decoded' cannot be nullptr.");
     79   }
     80 
     81   if (data.empty()) {
     82     decoded->clear();
     83     return Status::OK();
     84   }
     85 
     86   // This decoding procedure will write 3 * ceil(data.size() / 4) bytes to be
     87   // output buffer, then truncate if necessary. Therefore we must overestimate
     88   // and allocate sufficient amount. Currently max_decoded_size may overestimate
     89   // by up to 3 bytes.
     90   const size_t max_decoded_size = 3 * (data.size() / 4) + 3;
     91   std::unique_ptr<char[]> buffer(new char[max_decoded_size]);
     92   char* current = buffer.get();
     93   if (current == nullptr) {
     94     return errors::ResourceExhausted(
     95         "Failed to allocate buffer for decoded string.");
     96   }
     97 
     98   const char* b64 = data.data();
     99   const char* end = data.data() + data.size();
    100 
    101   while (end - b64 > 4) {
    102     TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current));
    103     b64 += 4;
    104     current += 3;
    105   }
    106 
    107   if (end - b64 == 4) {
    108     // The data length is a multiple of 4. Check for padding.
    109     // Base64 cannot have more than 2 paddings.
    110     if (b64[2] == kPadChar && b64[3] == kPadChar) {
    111       end -= 2;
    112     }
    113     if (b64[2] != kPadChar && b64[3] == kPadChar) {
    114       end -= 1;
    115     }
    116   }
    117 
    118   const int remain = static_cast<int>(end - b64);
    119   if (TF_PREDICT_FALSE(remain == 1)) {
    120     // We may check this condition early by checking data.size() % 4 == 1.
    121     return errors::InvalidArgument(
    122         "Base64 string length cannot be 1 modulo 4.");
    123   }
    124 
    125   // A valid base64 character will replace paddings, if any.
    126   char tail[4] = {kBase64UrlSafeChars[0], kBase64UrlSafeChars[0],
    127                   kBase64UrlSafeChars[0], kBase64UrlSafeChars[0]};
    128   // Copy tail of the input into the array, then decode.
    129   std::memcpy(tail, b64, remain * sizeof(*b64));
    130   TF_RETURN_IF_ERROR(DecodeThreeChars(tail, current));
    131   // We know how many parsed characters are valid.
    132   current += remain - 1;
    133 
    134   decoded->assign(buffer.get(), current - buffer.get());
    135   return Status::OK();
    136 }
    137 
    138 Status Base64Encode(StringPiece source, string* encoded) {
    139   return Base64Encode(source, false, encoded);
    140 }
    141 
    142 Status Base64Encode(StringPiece source, bool with_padding, string* encoded) {
    143   const char* const base64_chars = kBase64UrlSafeChars;
    144   if (encoded == nullptr) {
    145     return errors::Internal("'encoded' cannot be nullptr.");
    146   }
    147 
    148   // max_encoded_size may overestimate by up to 4 bytes.
    149   const size_t max_encoded_size = 4 * (source.size() / 3) + 4;
    150   std::unique_ptr<char[]> buffer(new char[max_encoded_size]);
    151   char* current = buffer.get();
    152   if (current == nullptr) {
    153     return errors::ResourceExhausted(
    154         "Failed to allocate buffer for encoded string.");
    155   }
    156 
    157   const char* data = source.data();
    158   const char* const end = source.data() + source.size();
    159 
    160   // Encode each block.
    161   while (end - data >= 3) {
    162     *current++ = base64_chars[(data[0] >> 2) & 0x3F];
    163     *current++ =
    164         base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
    165     *current++ =
    166         base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)];
    167     *current++ = base64_chars[data[2] & 0x3F];
    168 
    169     data += 3;
    170   }
    171 
    172   // Take care of the tail.
    173   if (end - data == 2) {
    174     *current++ = base64_chars[(data[0] >> 2) & 0x3F];
    175     *current++ =
    176         base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
    177     *current++ = base64_chars[(data[1] & 0x0F) << 2];
    178     if (with_padding) {
    179       *current++ = kPadChar;
    180     }
    181   } else if (end - data == 1) {
    182     *current++ = base64_chars[(data[0] >> 2) & 0x3F];
    183     *current++ = base64_chars[(data[0] & 0x03) << 4];
    184     if (with_padding) {
    185       *current++ = kPadChar;
    186       *current++ = kPadChar;
    187     }
    188   }
    189 
    190   encoded->assign(buffer.get(), current - buffer.get());
    191   return Status::OK();
    192 }
    193 
    194 }  // namespace tensorflow
    195