Home | History | Annotate | Download | only in io
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License
     15  */
     16 
     17 package libcore.io;
     18 
     19 import java.io.ByteArrayOutputStream;
     20 import java.nio.charset.StandardCharsets;
     21 
     22 /**
     23  * Perform encoding and decoding of Base64 byte arrays as described in
     24  * http://www.ietf.org/rfc/rfc2045.txt
     25  */
     26 public final class Base64 {
     27     private static final byte[] BASE_64_ALPHABET = initializeBase64Alphabet();
     28 
     29     private static byte[] initializeBase64Alphabet() {
     30         return "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
     31                     .getBytes(StandardCharsets.US_ASCII);
     32     }
     33 
     34     // Bit masks for the 4 output 6-bit values from 3 input bytes.
     35     private static final int FIRST_OUTPUT_BYTE_MASK = 0x3f << 18;
     36     private static final int SECOND_OUTPUT_BYTE_MASK = 0x3f << 12;
     37     private static final int THIRD_OUTPUT_BYTE_MASK = 0x3f << 6;
     38     private static final int FOURTH_OUTPUT_BYTE_MASK = 0x3f;
     39 
     40     private Base64() {}
     41 
     42     public static String encode(byte[] in) {
     43         int len = in.length;
     44         int outputLen = computeEncodingOutputLen(len);
     45         byte[] output = new byte[outputLen];
     46 
     47         int outputIndex = 0;
     48         for (int i = 0; i < len; i += 3) {
     49             // Only a "triplet" if there are there are at least three remaining bytes
     50             // in the input...
     51             // Mask with 0xff to avoid signed extension.
     52             int byteTripletAsInt = in[i] & 0xff;
     53             if (i + 1 < len) {
     54                 // Add second byte to the triplet.
     55                 byteTripletAsInt <<= 8;
     56                 byteTripletAsInt |= in[i + 1] & 0xff;
     57                 if (i + 2 < len) {
     58                     byteTripletAsInt <<= 8;
     59                     byteTripletAsInt |= in[i + 2] & 0xff;
     60                 } else {
     61                     // Insert 2 zero bits as to make output 18 bits long.
     62                     byteTripletAsInt <<= 2;
     63                 }
     64             } else {
     65                 // Insert 4 zero bits as to make output 12 bits long.
     66                 byteTripletAsInt <<= 4;
     67             }
     68 
     69             if (i + 2 < len) {
     70                 // The int may have up to 24 non-zero bits.
     71                 output[outputIndex++] = BASE_64_ALPHABET[
     72                         (byteTripletAsInt & FIRST_OUTPUT_BYTE_MASK) >>> 18];
     73             }
     74             if (i + 1 < len) {
     75                 // The int may have up to 18 non-zero bits.
     76                 output[outputIndex++] = BASE_64_ALPHABET[
     77                         (byteTripletAsInt & SECOND_OUTPUT_BYTE_MASK) >>> 12];
     78             }
     79             output[outputIndex++] = BASE_64_ALPHABET[
     80                     (byteTripletAsInt & THIRD_OUTPUT_BYTE_MASK) >>> 6];
     81             output[outputIndex++] = BASE_64_ALPHABET[
     82                     byteTripletAsInt & FOURTH_OUTPUT_BYTE_MASK];
     83         }
     84 
     85         int inLengthMod3 = len % 3;
     86         // Add padding as per the spec.
     87         if (inLengthMod3 > 0) {
     88             output[outputIndex++] = '=';
     89             if (inLengthMod3 == 1) {
     90                 output[outputIndex++] = '=';
     91             }
     92         }
     93 
     94         return new String(output, StandardCharsets.US_ASCII);
     95     }
     96 
     97     private static int computeEncodingOutputLen(int inLength) {
     98         int inLengthMod3 = inLength % 3;
     99         int outputLen = (inLength / 3) * 4;
    100         if (inLengthMod3 == 2) {
    101             // Need 3 6-bit characters as to express the last 16 bits, plus 1 padding.
    102             outputLen += 4;
    103         } else if (inLengthMod3 == 1) {
    104             // Need 2 6-bit characters as to express the last 8 bits, plus 2 padding.
    105             outputLen += 4;
    106         }
    107         return outputLen;
    108     }
    109 
    110     public static byte[] decode(byte[] in) {
    111         return decode(in, in.length);
    112     }
    113 
    114     /** Decodes the input from position 0 (inclusive) to len (exclusive). */
    115     public static byte[] decode(byte[] in, int len) {
    116         final int inLength = Math.min(in.length, len);
    117         // Overestimating 3 bytes per each 4 blocks of input (plus a possibly incomplete one).
    118         ByteArrayOutputStream output = new ByteArrayOutputStream((inLength / 4) * 3 + 3);
    119         // Position in the input. Use an array so we can pass it to {@code getNextByte}.
    120         int[] pos = new int[1];
    121 
    122         try {
    123             while (pos[0] < inLength) {
    124                 int byteTripletAsInt = 0;
    125 
    126                 // j is the index in a 4-tuple of 6-bit characters where are trying to read from the
    127                 // input.
    128                 for (int j = 0; j < 4; j++) {
    129                     byte c = getNextByte(in, pos, inLength);
    130                     if (c == END_OF_INPUT || c == PAD_AS_BYTE) {
    131                         // Padding or end of file...
    132                         switch (j) {
    133                             case 0:
    134                             case 1:
    135                                 return (c == END_OF_INPUT) ? output.toByteArray() : null;
    136                             case 2:
    137                                 // The input is over with two 6-bit characters: a single byte padded
    138                                 // with 4 extra 0's.
    139 
    140                                 if (c == END_OF_INPUT) {
    141                                     // Do not consider the block, since padding is not present.
    142                                     return checkNoTrailingAndReturn(output, in, pos[0], inLength);
    143                                 }
    144                                 // We are at a pad character, consume and look for the second one.
    145                                 pos[0]++;
    146                                 c = getNextByte(in, pos, inLength);
    147                                 if (c == END_OF_INPUT) {
    148                                     // Do not consider the block, since padding is not present.
    149                                     return checkNoTrailingAndReturn(output, in, pos[0], inLength);
    150                                 }
    151                                 if (c == PAD_AS_BYTE) {
    152                                     byteTripletAsInt >>= 4;
    153                                     output.write(byteTripletAsInt);
    154                                     return checkNoTrailingAndReturn(output, in, pos[0], inLength);
    155                                 }
    156                                 // Something other than pad and non-alphabet characters, illegal.
    157                                 return null;
    158 
    159 
    160                             case 3:
    161                                 // The input is over with three 6-bit characters: two bytes padded
    162                                 // with 2 extra 0's.
    163                                 if (c == PAD_AS_BYTE) {
    164                                     // Consider the block only if padding is present.
    165                                     byteTripletAsInt >>= 2;
    166                                     output.write(byteTripletAsInt >> 8);
    167                                     output.write(byteTripletAsInt & 0xff);
    168                                 }
    169                                 return checkNoTrailingAndReturn(output, in, pos[0], inLength);
    170                         }
    171                     } else {
    172                         byteTripletAsInt <<= 6;
    173                         byteTripletAsInt += (c & 0xff);
    174                         pos[0]++;
    175                     }
    176                 }
    177                 // We have four 6-bit characters: output the corresponding 3 bytes
    178                 output.write(byteTripletAsInt >> 16);
    179                 output.write((byteTripletAsInt >> 8) & 0xff);
    180                 output.write(byteTripletAsInt & 0xff);
    181             }
    182             return checkNoTrailingAndReturn(output, in, pos[0], inLength);
    183         } catch (InvalidBase64ByteException e) {
    184             return null;
    185         }
    186     }
    187 
    188     /**
    189      * On decoding, an illegal character always return null.
    190      *
    191      * Using this exception to avoid "if" checks every time.
    192      */
    193 
    194     private static class InvalidBase64ByteException extends Exception { }
    195 
    196     /**
    197      * Obtain the numeric value corresponding to the next relevant byte in the input.
    198      *
    199      * Calculates the numeric value (6-bit, 0 <= x <= 63) of the next Base64 encoded byte in
    200      * {@code in} at or after {@code pos[0]} and before {@code inLength}. Returns
    201      * {@link #WHITESPACE_AS_BYTE}, {@link #PAD_AS_BYTE}, {@link #END_OF_INPUT} or the 6-bit value.
    202      * {@code pos[0]} is updated as a side effect of this method.
    203      */
    204     private static byte getNextByte(byte[] in, int[] pos, int inLength)
    205             throws InvalidBase64ByteException {
    206         // Ignore all whitespace.
    207         while (pos[0] < inLength) {
    208             byte c = base64AlphabetToNumericalValue(in[pos[0]]);
    209             if (c != WHITESPACE_AS_BYTE) {
    210                 return c;
    211             }
    212             pos[0]++;
    213         }
    214         return END_OF_INPUT;
    215     }
    216 
    217     /**
    218      * Check that there are no invalid trailing characters (ie, other then whitespace and padding)
    219      *
    220      * Returns {@code output} as a byte array in case of success, {@code null} in case of invalid
    221      * characters.
    222      */
    223     private static byte[] checkNoTrailingAndReturn(
    224             ByteArrayOutputStream output, byte[] in, int i, int inLength)
    225                     throws InvalidBase64ByteException{
    226         while (i < inLength) {
    227             byte c = base64AlphabetToNumericalValue(in[i]);
    228             if (c != WHITESPACE_AS_BYTE && c != PAD_AS_BYTE) {
    229                 return null;
    230             }
    231             i++;
    232         }
    233         return output.toByteArray();
    234     }
    235 
    236     private static final byte PAD_AS_BYTE = -1;
    237     private static final byte WHITESPACE_AS_BYTE = -2;
    238     private static final byte END_OF_INPUT = -3;
    239     private static byte base64AlphabetToNumericalValue(byte c) throws InvalidBase64ByteException {
    240         if ('A' <= c && c <= 'Z') {
    241             return (byte) (c - 'A');
    242         }
    243         if ('a' <= c && c <= 'z') {
    244             return (byte) (c - 'a' + 26);
    245         }
    246         if ('0' <= c && c <= '9') {
    247             return (byte) (c - '0' + 52);
    248         }
    249         if (c == '+') {
    250             return (byte) 62;
    251         }
    252         if (c == '/') {
    253             return (byte) 63;
    254         }
    255         if (c == '=') {
    256             return PAD_AS_BYTE;
    257         }
    258         if (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
    259             return WHITESPACE_AS_BYTE;
    260         }
    261         throw new InvalidBase64ByteException();
    262     }
    263 }
    264