Home | History | Annotate | Download | only in apf
      1 /*
      2  * Copyright 2016, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  * http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "apf_interpreter.h"
     18 
     19 #include <string.h> // For memcmp
     20 
     21 #include "apf.h"
     22 
     23 // Return code indicating "packet" should accepted.
     24 #define PASS_PACKET 1
     25 // Return code indicating "packet" should be dropped.
     26 #define DROP_PACKET 0
     27 // Verify an internal condition and accept packet if it fails.
     28 #define ASSERT_RETURN(c) if (!(c)) return PASS_PACKET
     29 // If "c" is of an unsigned type, generate a compile warning that gets promoted to an error.
     30 // This makes bounds checking simpler because ">= 0" can be avoided. Otherwise adding
     31 // superfluous ">= 0" with unsigned expressions generates compile warnings.
     32 #define ENFORCE_UNSIGNED(c) ((c)==(uint32_t)(c))
     33 
     34 /**
     35  * Runs a packet filtering program over a packet.
     36  *
     37  * @param program the program bytecode.
     38  * @param program_len the length of {@code apf_program} in bytes.
     39  * @param packet the packet bytes, starting from the 802.3 header and not
     40  *               including any CRC bytes at the end.
     41  * @param packet_len the length of {@code packet} in bytes.
     42  * @param filter_age the number of seconds since the filter was programmed.
     43  *
     44  * @return non-zero if packet should be passed to AP, zero if
     45  *         packet should be dropped.
     46  */
     47 int accept_packet(const uint8_t* program, uint32_t program_len,
     48                   const uint8_t* packet, uint32_t packet_len,
     49                   uint32_t filter_age) {
     50 // Is offset within program bounds?
     51 #define IN_PROGRAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < program_len)
     52 // Is offset within packet bounds?
     53 #define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < packet_len)
     54 // Accept packet if not within program bounds
     55 #define ASSERT_IN_PROGRAM_BOUNDS(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p))
     56 // Accept packet if not within packet bounds
     57 #define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p))
     58   // Program counter.
     59   uint32_t pc = 0;
     60 // Accept packet if not within program or not ahead of program counter
     61 #define ASSERT_FORWARD_IN_PROGRAM(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p) && (p) >= pc)
     62   // Memory slot values.
     63   uint32_t memory[MEMORY_ITEMS] = {};
     64   // Fill in pre-filled memory slot values.
     65   memory[MEMORY_OFFSET_PACKET_SIZE] = packet_len;
     66   memory[MEMORY_OFFSET_FILTER_AGE] = filter_age;
     67   ASSERT_IN_PACKET_BOUNDS(APF_FRAME_HEADER_SIZE);
     68   // Only populate if IP version is IPv4.
     69   if ((packet[APF_FRAME_HEADER_SIZE] & 0xf0) == 0x40) {
     70       memory[MEMORY_OFFSET_IPV4_HEADER_SIZE] = (packet[APF_FRAME_HEADER_SIZE] & 15) * 4;
     71   }
     72   // Register values.
     73   uint32_t registers[2] = {};
     74   // Count of instructions remaining to execute. This is done to ensure an
     75   // upper bound on execution time. It should never be hit and is only for
     76   // safety. Initialize to the number of bytes in the program which is an
     77   // upper bound on the number of instructions in the program.
     78   uint32_t instructions_remaining = program_len;
     79 
     80   do {
     81       if (pc == program_len) {
     82           return PASS_PACKET;
     83       } else if (pc == (program_len + 1)) {
     84           return DROP_PACKET;
     85       }
     86       ASSERT_IN_PROGRAM_BOUNDS(pc);
     87       const uint8_t bytecode = program[pc++];
     88       const uint32_t opcode = EXTRACT_OPCODE(bytecode);
     89       const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
     90 #define REG (registers[reg_num])
     91 #define OTHER_REG (registers[reg_num ^ 1])
     92       // All instructions have immediate fields, so load them now.
     93       const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
     94       uint32_t imm = 0;
     95       int32_t signed_imm = 0;
     96       if (len_field != 0) {
     97           const uint32_t imm_len = 1 << (len_field - 1);
     98           ASSERT_FORWARD_IN_PROGRAM(pc + imm_len - 1);
     99           uint32_t i;
    100           for (i = 0; i < imm_len; i++)
    101               imm = (imm << 8) | program[pc++];
    102           // Sign extend imm into signed_imm.
    103           signed_imm = imm << ((4 - imm_len) * 8);
    104           signed_imm >>= (4 - imm_len) * 8;
    105       }
    106       switch (opcode) {
    107           case LDB_OPCODE:
    108           case LDH_OPCODE:
    109           case LDW_OPCODE:
    110           case LDBX_OPCODE:
    111           case LDHX_OPCODE:
    112           case LDWX_OPCODE: {
    113               uint32_t offs = imm;
    114               if (opcode >= LDBX_OPCODE) {
    115                   // Note: this can overflow and actually decrease offs.
    116                   offs += registers[1];
    117               }
    118               ASSERT_IN_PACKET_BOUNDS(offs);
    119               uint32_t load_size;
    120               switch (opcode) {
    121                   case LDB_OPCODE:
    122                   case LDBX_OPCODE:
    123                     load_size = 1;
    124                     break;
    125                   case LDH_OPCODE:
    126                   case LDHX_OPCODE:
    127                     load_size = 2;
    128                     break;
    129                   case LDW_OPCODE:
    130                   case LDWX_OPCODE:
    131                     load_size = 4;
    132                     break;
    133                   // Immediately enclosing switch statement guarantees
    134                   // opcode cannot be any other value.
    135               }
    136               const uint32_t end_offs = offs + (load_size - 1);
    137               // Catch overflow/wrap-around.
    138               ASSERT_RETURN(end_offs >= offs);
    139               ASSERT_IN_PACKET_BOUNDS(end_offs);
    140               uint32_t val = 0;
    141               while (load_size--)
    142                   val = (val << 8) | packet[offs++];
    143               REG = val;
    144               break;
    145           }
    146           case JMP_OPCODE:
    147               // This can jump backwards. Infinite looping prevented by instructions_remaining.
    148               pc += imm;
    149               break;
    150           case JEQ_OPCODE:
    151           case JNE_OPCODE:
    152           case JGT_OPCODE:
    153           case JLT_OPCODE:
    154           case JSET_OPCODE:
    155           case JNEBS_OPCODE: {
    156               // Load second immediate field.
    157               uint32_t cmp_imm = 0;
    158               if (reg_num == 1) {
    159                   cmp_imm = registers[1];
    160               } else if (len_field != 0) {
    161                   uint32_t cmp_imm_len = 1 << (len_field - 1);
    162                   ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm_len - 1);
    163                   uint32_t i;
    164                   for (i = 0; i < cmp_imm_len; i++)
    165                       cmp_imm = (cmp_imm << 8) | program[pc++];
    166               }
    167               switch (opcode) {
    168                   case JEQ_OPCODE:
    169                       if (registers[0] == cmp_imm)
    170                           pc += imm;
    171                       break;
    172                   case JNE_OPCODE:
    173                       if (registers[0] != cmp_imm)
    174                           pc += imm;
    175                       break;
    176                   case JGT_OPCODE:
    177                       if (registers[0] > cmp_imm)
    178                           pc += imm;
    179                       break;
    180                   case JLT_OPCODE:
    181                       if (registers[0] < cmp_imm)
    182                           pc += imm;
    183                       break;
    184                   case JSET_OPCODE:
    185                       if (registers[0] & cmp_imm)
    186                           pc += imm;
    187                       break;
    188                   case JNEBS_OPCODE: {
    189                       // cmp_imm is size in bytes of data to compare.
    190                       // pc is offset of program bytes to compare.
    191                       // imm is jump target offset.
    192                       // REG is offset of packet bytes to compare.
    193                       ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1);
    194                       ASSERT_IN_PACKET_BOUNDS(REG);
    195                       const uint32_t last_packet_offs = REG + cmp_imm - 1;
    196                       ASSERT_RETURN(last_packet_offs >= REG);
    197                       ASSERT_IN_PACKET_BOUNDS(last_packet_offs);
    198                       if (memcmp(program + pc, packet + REG, cmp_imm))
    199                           pc += imm;
    200                       // skip past comparison bytes
    201                       pc += cmp_imm;
    202                       break;
    203                   }
    204               }
    205               break;
    206           }
    207           case ADD_OPCODE:
    208               registers[0] += reg_num ? registers[1] : imm;
    209               break;
    210           case MUL_OPCODE:
    211               registers[0] *= reg_num ? registers[1] : imm;
    212               break;
    213           case DIV_OPCODE: {
    214               const uint32_t div_operand = reg_num ? registers[1] : imm;
    215               ASSERT_RETURN(div_operand);
    216               registers[0] /= div_operand;
    217               break;
    218           }
    219           case AND_OPCODE:
    220               registers[0] &= reg_num ? registers[1] : imm;
    221               break;
    222           case OR_OPCODE:
    223               registers[0] |= reg_num ? registers[1] : imm;
    224               break;
    225           case SH_OPCODE: {
    226               const int32_t shift_val = reg_num ? (int32_t)registers[1] : signed_imm;
    227               if (shift_val > 0)
    228                   registers[0] <<= shift_val;
    229               else
    230                   registers[0] >>= -shift_val;
    231               break;
    232           }
    233           case LI_OPCODE:
    234               REG = signed_imm;
    235               break;
    236           case EXT_OPCODE:
    237               if (
    238 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
    239 // instead just enforce that imm is unsigned (so it's always greater or equal to 0).
    240 #if LDM_EXT_OPCODE == 0
    241                   ENFORCE_UNSIGNED(imm) &&
    242 #else
    243                   imm >= LDM_EXT_OPCODE &&
    244 #endif
    245                   imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
    246                 REG = memory[imm - LDM_EXT_OPCODE];
    247               } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
    248                 memory[imm - STM_EXT_OPCODE] = REG;
    249               } else switch (imm) {
    250                   case NOT_EXT_OPCODE:
    251                     REG = ~REG;
    252                     break;
    253                   case NEG_EXT_OPCODE:
    254                     REG = -REG;
    255                     break;
    256                   case SWAP_EXT_OPCODE: {
    257                     uint32_t tmp = REG;
    258                     REG = OTHER_REG;
    259                     OTHER_REG = tmp;
    260                     break;
    261                   }
    262                   case MOV_EXT_OPCODE:
    263                     REG = OTHER_REG;
    264                     break;
    265                   // Unknown extended opcode
    266                   default:
    267                     // Bail out
    268                     return PASS_PACKET;
    269               }
    270               break;
    271           // Unknown opcode
    272           default:
    273               // Bail out
    274               return PASS_PACKET;
    275       }
    276   } while (instructions_remaining--);
    277   return PASS_PACKET;
    278 }
    279