/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/strings/base64.h"

#include <cstring>
#include <memory>
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {
// Reverse lookup table for decoding: indexed by a 7-bit ASCII code, it yields
// that character's 6-bit value in the URL-safe base64 alphabet ('A'-'Z' -> 0-25,
// 'a'-'z' -> 26-51, '0'-'9' -> 52-61, '-' -> 62, '_' -> 63), or -1 for any
// character outside the alphabet (including '+', '/' and '=').
// This array must have signed type: Convert() relies on the -1 entries being
// negative so that sign extension flags invalid characters.
// clang-format off
constexpr int8 kBase64Bytes[128] = {
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1, 0x3E,   -1,   -1,
  0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D,   -1,   -1,
    -1,   -1,   -1,   -1,   -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
  0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
  0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19,   -1,   -1,   -1,   -1, 0x3F,
    -1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
  0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
  0x31, 0x32, 0x33,   -1,   -1,   -1,   -1,   -1};
// clang-format on

// Encoding alphabet: the character at index i encodes the 6-bit value i. This
// is the URL-safe variant ('-' and '_' in place of '+' and '/'); the extra
// 65th byte is the string literal's terminating NUL.
constexpr char kBase64UrlSafeChars[65] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

// Padding character. The encoder emits it only when padding is requested; the
// decoder accepts it only in the last two positions of the final 4-char group.
constexpr char kPadChar = '=';
If the input char 46 // is invalid for base64 encoding, the return value has at least its upper 25 47 // bits set. 48 inline uint32 Convert(char x) { 49 // If x < 128, then we look up x in the table. If x is valid, then the table 50 // will have a value <= 0x3F, otherwise the table will have -1. If x >= 128, 51 // we still do some table lookup, but the value is ignored since we explicitly 52 // set the high bit of y to 1. Either way, y is negative (high bit set) in 53 // case of error. 54 const int8 y = kBase64Bytes[x & 0x7F] | (x & 0x80); 55 // Casting from int8 to int32 preserves sign by sign extension. If y was 56 // negative, at least its 25 high bits of the return value are set. 57 const int32 z = static_cast<int32>(y); 58 return static_cast<uint32>(z); 59 } 60 61 Status DecodeThreeChars(const char* codes, char* result) { 62 const uint32 packed = (Convert(codes[0]) << 18) | (Convert(codes[1]) << 12) | 63 (Convert(codes[2]) << 6) | (Convert(codes[3])); 64 // Convert() return value has upper 25 bits set if input is invalid. 65 // Therefore `packed` has high bits set iff at least one of code is invalid. 66 if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) { 67 return errors::InvalidArgument("Invalid character found in base64."); 68 } 69 result[0] = static_cast<char>(packed >> 16); 70 result[1] = static_cast<char>(packed >> 8); 71 result[2] = static_cast<char>(packed); 72 return Status::OK(); 73 } 74 } // namespace 75 76 Status Base64Decode(StringPiece data, string* decoded) { 77 if (decoded == nullptr) { 78 return errors::Internal("'decoded' cannot be nullptr."); 79 } 80 81 if (data.empty()) { 82 decoded->clear(); 83 return Status::OK(); 84 } 85 86 // This decoding procedure will write 3 * ceil(data.size() / 4) bytes to be 87 // output buffer, then truncate if necessary. Therefore we must overestimate 88 // and allocate sufficient amount. Currently max_decoded_size may overestimate 89 // by up to 3 bytes. 
90 const size_t max_decoded_size = 3 * (data.size() / 4) + 3; 91 std::unique_ptr<char[]> buffer(new char[max_decoded_size]); 92 char* current = buffer.get(); 93 if (current == nullptr) { 94 return errors::ResourceExhausted( 95 "Failed to allocate buffer for decoded string."); 96 } 97 98 const char* b64 = data.data(); 99 const char* end = data.data() + data.size(); 100 101 while (end - b64 > 4) { 102 TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current)); 103 b64 += 4; 104 current += 3; 105 } 106 107 if (end - b64 == 4) { 108 // The data length is a multiple of 4. Check for padding. 109 // Base64 cannot have more than 2 paddings. 110 if (b64[2] == kPadChar && b64[3] == kPadChar) { 111 end -= 2; 112 } 113 if (b64[2] != kPadChar && b64[3] == kPadChar) { 114 end -= 1; 115 } 116 } 117 118 const int remain = static_cast<int>(end - b64); 119 if (TF_PREDICT_FALSE(remain == 1)) { 120 // We may check this condition early by checking data.size() % 4 == 1. 121 return errors::InvalidArgument( 122 "Base64 string length cannot be 1 modulo 4."); 123 } 124 125 // A valid base64 character will replace paddings, if any. 126 char tail[4] = {kBase64UrlSafeChars[0], kBase64UrlSafeChars[0], 127 kBase64UrlSafeChars[0], kBase64UrlSafeChars[0]}; 128 // Copy tail of the input into the array, then decode. 129 std::memcpy(tail, b64, remain * sizeof(*b64)); 130 TF_RETURN_IF_ERROR(DecodeThreeChars(tail, current)); 131 // We know how many parsed characters are valid. 
132 current += remain - 1; 133 134 decoded->assign(buffer.get(), current - buffer.get()); 135 return Status::OK(); 136 } 137 138 Status Base64Encode(StringPiece source, string* encoded) { 139 return Base64Encode(source, false, encoded); 140 } 141 142 Status Base64Encode(StringPiece source, bool with_padding, string* encoded) { 143 const char* const base64_chars = kBase64UrlSafeChars; 144 if (encoded == nullptr) { 145 return errors::Internal("'encoded' cannot be nullptr."); 146 } 147 148 // max_encoded_size may overestimate by up to 4 bytes. 149 const size_t max_encoded_size = 4 * (source.size() / 3) + 4; 150 std::unique_ptr<char[]> buffer(new char[max_encoded_size]); 151 char* current = buffer.get(); 152 if (current == nullptr) { 153 return errors::ResourceExhausted( 154 "Failed to allocate buffer for encoded string."); 155 } 156 157 const char* data = source.data(); 158 const char* const end = source.data() + source.size(); 159 160 // Encode each block. 161 while (end - data >= 3) { 162 *current++ = base64_chars[(data[0] >> 2) & 0x3F]; 163 *current++ = 164 base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)]; 165 *current++ = 166 base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)]; 167 *current++ = base64_chars[data[2] & 0x3F]; 168 169 data += 3; 170 } 171 172 // Take care of the tail. 173 if (end - data == 2) { 174 *current++ = base64_chars[(data[0] >> 2) & 0x3F]; 175 *current++ = 176 base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)]; 177 *current++ = base64_chars[(data[1] & 0x0F) << 2]; 178 if (with_padding) { 179 *current++ = kPadChar; 180 } 181 } else if (end - data == 1) { 182 *current++ = base64_chars[(data[0] >> 2) & 0x3F]; 183 *current++ = base64_chars[(data[0] & 0x03) << 4]; 184 if (with_padding) { 185 *current++ = kPadChar; 186 *current++ = kPadChar; 187 } 188 } 189 190 encoded->assign(buffer.get(), current - buffer.get()); 191 return Status::OK(); 192 } 193 194 } // namespace tensorflow 195