1 /* 2 * Copyright 2016, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "apf_interpreter.h" 18 19 #include <string.h> // For memcmp 20 21 #include "apf.h" 22 23 // Return code indicating "packet" should accepted. 24 #define PASS_PACKET 1 25 // Return code indicating "packet" should be dropped. 26 #define DROP_PACKET 0 27 // Verify an internal condition and accept packet if it fails. 28 #define ASSERT_RETURN(c) if (!(c)) return PASS_PACKET 29 // If "c" is of an unsigned type, generate a compile warning that gets promoted to an error. 30 // This makes bounds checking simpler because ">= 0" can be avoided. Otherwise adding 31 // superfluous ">= 0" with unsigned expressions generates compile warnings. 32 #define ENFORCE_UNSIGNED(c) ((c)==(uint32_t)(c)) 33 34 /** 35 * Runs a packet filtering program over a packet. 36 * 37 * @param program the program bytecode. 38 * @param program_len the length of {@code apf_program} in bytes. 39 * @param packet the packet bytes, starting from the 802.3 header and not 40 * including any CRC bytes at the end. 41 * @param packet_len the length of {@code packet} in bytes. 42 * @param filter_age the number of seconds since the filter was programmed. 43 * 44 * @return non-zero if packet should be passed to AP, zero if 45 * packet should be dropped. 46 */ 47 int accept_packet(const uint8_t* program, uint32_t program_len, 48 const uint8_t* packet, uint32_t packet_len, 49 uint32_t filter_age) { 50 // Is offset within program bounds? 51 #define IN_PROGRAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < program_len) 52 // Is offset within packet bounds? 53 #define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < packet_len) 54 // Accept packet if not within program bounds 55 #define ASSERT_IN_PROGRAM_BOUNDS(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p)) 56 // Accept packet if not within packet bounds 57 #define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p)) 58 // Program counter. 59 uint32_t pc = 0; 60 // Accept packet if not within program or not ahead of program counter 61 #define ASSERT_FORWARD_IN_PROGRAM(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p) && (p) >= pc) 62 // Memory slot values. 63 uint32_t memory[MEMORY_ITEMS] = {}; 64 // Fill in pre-filled memory slot values. 65 memory[MEMORY_OFFSET_PACKET_SIZE] = packet_len; 66 memory[MEMORY_OFFSET_FILTER_AGE] = filter_age; 67 ASSERT_IN_PACKET_BOUNDS(APF_FRAME_HEADER_SIZE); 68 // Only populate if IP version is IPv4. 69 if ((packet[APF_FRAME_HEADER_SIZE] & 0xf0) == 0x40) { 70 memory[MEMORY_OFFSET_IPV4_HEADER_SIZE] = (packet[APF_FRAME_HEADER_SIZE] & 15) * 4; 71 } 72 // Register values. 73 uint32_t registers[2] = {}; 74 // Count of instructions remaining to execute. This is done to ensure an 75 // upper bound on execution time. It should never be hit and is only for 76 // safety. Initialize to the number of bytes in the program which is an 77 // upper bound on the number of instructions in the program. 78 uint32_t instructions_remaining = program_len; 79 80 do { 81 if (pc == program_len) { 82 return PASS_PACKET; 83 } else if (pc == (program_len + 1)) { 84 return DROP_PACKET; 85 } 86 ASSERT_IN_PROGRAM_BOUNDS(pc); 87 const uint8_t bytecode = program[pc++]; 88 const uint32_t opcode = EXTRACT_OPCODE(bytecode); 89 const uint32_t reg_num = EXTRACT_REGISTER(bytecode); 90 #define REG (registers[reg_num]) 91 #define OTHER_REG (registers[reg_num ^ 1]) 92 // All instructions have immediate fields, so load them now. 93 const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode); 94 uint32_t imm = 0; 95 int32_t signed_imm = 0; 96 if (len_field != 0) { 97 const uint32_t imm_len = 1 << (len_field - 1); 98 ASSERT_FORWARD_IN_PROGRAM(pc + imm_len - 1); 99 uint32_t i; 100 for (i = 0; i < imm_len; i++) 101 imm = (imm << 8) | program[pc++]; 102 // Sign extend imm into signed_imm. 103 signed_imm = imm << ((4 - imm_len) * 8); 104 signed_imm >>= (4 - imm_len) * 8; 105 } 106 switch (opcode) { 107 case LDB_OPCODE: 108 case LDH_OPCODE: 109 case LDW_OPCODE: 110 case LDBX_OPCODE: 111 case LDHX_OPCODE: 112 case LDWX_OPCODE: { 113 uint32_t offs = imm; 114 if (opcode >= LDBX_OPCODE) { 115 // Note: this can overflow and actually decrease offs. 116 offs += registers[1]; 117 } 118 ASSERT_IN_PACKET_BOUNDS(offs); 119 uint32_t load_size; 120 switch (opcode) { 121 case LDB_OPCODE: 122 case LDBX_OPCODE: 123 load_size = 1; 124 break; 125 case LDH_OPCODE: 126 case LDHX_OPCODE: 127 load_size = 2; 128 break; 129 case LDW_OPCODE: 130 case LDWX_OPCODE: 131 load_size = 4; 132 break; 133 // Immediately enclosing switch statement guarantees 134 // opcode cannot be any other value. 135 } 136 const uint32_t end_offs = offs + (load_size - 1); 137 // Catch overflow/wrap-around. 138 ASSERT_RETURN(end_offs >= offs); 139 ASSERT_IN_PACKET_BOUNDS(end_offs); 140 uint32_t val = 0; 141 while (load_size--) 142 val = (val << 8) | packet[offs++]; 143 REG = val; 144 break; 145 } 146 case JMP_OPCODE: 147 // This can jump backwards. Infinite looping prevented by instructions_remaining. 148 pc += imm; 149 break; 150 case JEQ_OPCODE: 151 case JNE_OPCODE: 152 case JGT_OPCODE: 153 case JLT_OPCODE: 154 case JSET_OPCODE: 155 case JNEBS_OPCODE: { 156 // Load second immediate field. 157 uint32_t cmp_imm = 0; 158 if (reg_num == 1) { 159 cmp_imm = registers[1]; 160 } else if (len_field != 0) { 161 uint32_t cmp_imm_len = 1 << (len_field - 1); 162 ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm_len - 1); 163 uint32_t i; 164 for (i = 0; i < cmp_imm_len; i++) 165 cmp_imm = (cmp_imm << 8) | program[pc++]; 166 } 167 switch (opcode) { 168 case JEQ_OPCODE: 169 if (registers[0] == cmp_imm) 170 pc += imm; 171 break; 172 case JNE_OPCODE: 173 if (registers[0] != cmp_imm) 174 pc += imm; 175 break; 176 case JGT_OPCODE: 177 if (registers[0] > cmp_imm) 178 pc += imm; 179 break; 180 case JLT_OPCODE: 181 if (registers[0] < cmp_imm) 182 pc += imm; 183 break; 184 case JSET_OPCODE: 185 if (registers[0] & cmp_imm) 186 pc += imm; 187 break; 188 case JNEBS_OPCODE: { 189 // cmp_imm is size in bytes of data to compare. 190 // pc is offset of program bytes to compare. 191 // imm is jump target offset. 192 // REG is offset of packet bytes to compare. 193 ASSERT_FORWARD_IN_PROGRAM(pc + cmp_imm - 1); 194 ASSERT_IN_PACKET_BOUNDS(REG); 195 const uint32_t last_packet_offs = REG + cmp_imm - 1; 196 ASSERT_RETURN(last_packet_offs >= REG); 197 ASSERT_IN_PACKET_BOUNDS(last_packet_offs); 198 if (memcmp(program + pc, packet + REG, cmp_imm)) 199 pc += imm; 200 // skip past comparison bytes 201 pc += cmp_imm; 202 break; 203 } 204 } 205 break; 206 } 207 case ADD_OPCODE: 208 registers[0] += reg_num ? registers[1] : imm; 209 break; 210 case MUL_OPCODE: 211 registers[0] *= reg_num ? registers[1] : imm; 212 break; 213 case DIV_OPCODE: { 214 const uint32_t div_operand = reg_num ? registers[1] : imm; 215 ASSERT_RETURN(div_operand); 216 registers[0] /= div_operand; 217 break; 218 } 219 case AND_OPCODE: 220 registers[0] &= reg_num ? registers[1] : imm; 221 break; 222 case OR_OPCODE: 223 registers[0] |= reg_num ? registers[1] : imm; 224 break; 225 case SH_OPCODE: { 226 const int32_t shift_val = reg_num ? (int32_t)registers[1] : signed_imm; 227 if (shift_val > 0) 228 registers[0] <<= shift_val; 229 else 230 registers[0] >>= -shift_val; 231 break; 232 } 233 case LI_OPCODE: 234 REG = signed_imm; 235 break; 236 case EXT_OPCODE: 237 if ( 238 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result, 239 // instead just enforce that imm is unsigned (so it's always greater or equal to 0). 240 #if LDM_EXT_OPCODE == 0 241 ENFORCE_UNSIGNED(imm) && 242 #else 243 imm >= LDM_EXT_OPCODE && 244 #endif 245 imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) { 246 REG = memory[imm - LDM_EXT_OPCODE]; 247 } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) { 248 memory[imm - STM_EXT_OPCODE] = REG; 249 } else switch (imm) { 250 case NOT_EXT_OPCODE: 251 REG = ~REG; 252 break; 253 case NEG_EXT_OPCODE: 254 REG = -REG; 255 break; 256 case SWAP_EXT_OPCODE: { 257 uint32_t tmp = REG; 258 REG = OTHER_REG; 259 OTHER_REG = tmp; 260 break; 261 } 262 case MOV_EXT_OPCODE: 263 REG = OTHER_REG; 264 break; 265 // Unknown extended opcode 266 default: 267 // Bail out 268 return PASS_PACKET; 269 } 270 break; 271 // Unknown opcode 272 default: 273 // Bail out 274 return PASS_PACKET; 275 } 276 } while (instructions_remaining--); 277 return PASS_PACKET; 278 } 279