Home | History | Annotate | Download | only in net
      1 // SPDX-License-Identifier: GPL-2.0
      2 /*
      3  * Test functionality of BPF filters with SO_REUSEPORT. Same test as
      4  * in reuseport_bpf_cpu, only as one socket per NUMA node.
      5  */
      6 
      7 #define _GNU_SOURCE
      8 
      9 #include <arpa/inet.h>
     10 #include <errno.h>
     11 #include <error.h>
     12 #include <linux/filter.h>
     13 #include <linux/bpf.h>
     14 #include <linux/in.h>
     15 #include <linux/unistd.h>
     16 #include <sched.h>
     17 #include <stdio.h>
     18 #include <stdlib.h>
     19 #include <string.h>
     20 #include <sys/epoll.h>
     21 #include <sys/types.h>
     22 #include <sys/socket.h>
     23 #include <unistd.h>
     24 #include <numa.h>
     25 
     /* Port the receive sockets bind to and the send sockets connect to. */
     26 static const int PORT = 8888;
     27 
     28 static void build_rcv_group(int *rcv_fd, size_t len, int family, int proto)
     29 {
     30 	struct sockaddr_storage addr;
     31 	struct sockaddr_in  *addr4;
     32 	struct sockaddr_in6 *addr6;
     33 	size_t i;
     34 	int opt;
     35 
     36 	switch (family) {
     37 	case AF_INET:
     38 		addr4 = (struct sockaddr_in *)&addr;
     39 		addr4->sin_family = AF_INET;
     40 		addr4->sin_addr.s_addr = htonl(INADDR_ANY);
     41 		addr4->sin_port = htons(PORT);
     42 		break;
     43 	case AF_INET6:
     44 		addr6 = (struct sockaddr_in6 *)&addr;
     45 		addr6->sin6_family = AF_INET6;
     46 		addr6->sin6_addr = in6addr_any;
     47 		addr6->sin6_port = htons(PORT);
     48 		break;
     49 	default:
     50 		error(1, 0, "Unsupported family %d", family);
     51 	}
     52 
     53 	for (i = 0; i < len; ++i) {
     54 		rcv_fd[i] = socket(family, proto, 0);
     55 		if (rcv_fd[i] < 0)
     56 			error(1, errno, "failed to create receive socket");
     57 
     58 		opt = 1;
     59 		if (setsockopt(rcv_fd[i], SOL_SOCKET, SO_REUSEPORT, &opt,
     60 			       sizeof(opt)))
     61 			error(1, errno, "failed to set SO_REUSEPORT");
     62 
     63 		if (bind(rcv_fd[i], (struct sockaddr *)&addr, sizeof(addr)))
     64 			error(1, errno, "failed to bind receive socket");
     65 
     66 		if (proto == SOCK_STREAM && listen(rcv_fd[i], len * 10))
     67 			error(1, errno, "failed to listen on receive port");
     68 	}
     69 }
     70 
     71 static void attach_bpf(int fd)
     72 {
     73 	static char bpf_log_buf[65536];
     74 	static const char bpf_license[] = "";
     75 
     76 	int bpf_fd;
     77 	const struct bpf_insn prog[] = {
     78 		/* R0 = bpf_get_numa_node_id() */
     79 		{ BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_get_numa_node_id },
     80 		/* return R0 */
     81 		{ BPF_JMP | BPF_EXIT, 0, 0, 0, 0 }
     82 	};
     83 	union bpf_attr attr;
     84 
     85 	memset(&attr, 0, sizeof(attr));
     86 	attr.prog_type = BPF_PROG_TYPE_SOCKET_FILTER;
     87 	attr.insn_cnt = sizeof(prog) / sizeof(prog[0]);
     88 	attr.insns = (unsigned long) &prog;
     89 	attr.license = (unsigned long) &bpf_license;
     90 	attr.log_buf = (unsigned long) &bpf_log_buf;
     91 	attr.log_size = sizeof(bpf_log_buf);
     92 	attr.log_level = 1;
     93 
     94 	bpf_fd = syscall(__NR_bpf, BPF_PROG_LOAD, &attr, sizeof(attr));
     95 	if (bpf_fd < 0)
     96 		error(1, errno, "ebpf error. log:\n%s\n", bpf_log_buf);
     97 
     98 	if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &bpf_fd,
     99 			sizeof(bpf_fd)))
    100 		error(1, errno, "failed to set SO_ATTACH_REUSEPORT_EBPF");
    101 
    102 	close(bpf_fd);
    103 }
    104 
    105 static void send_from_node(int node_id, int family, int proto)
    106 {
    107 	struct sockaddr_storage saddr, daddr;
    108 	struct sockaddr_in  *saddr4, *daddr4;
    109 	struct sockaddr_in6 *saddr6, *daddr6;
    110 	int fd;
    111 
    112 	switch (family) {
    113 	case AF_INET:
    114 		saddr4 = (struct sockaddr_in *)&saddr;
    115 		saddr4->sin_family = AF_INET;
    116 		saddr4->sin_addr.s_addr = htonl(INADDR_ANY);
    117 		saddr4->sin_port = 0;
    118 
    119 		daddr4 = (struct sockaddr_in *)&daddr;
    120 		daddr4->sin_family = AF_INET;
    121 		daddr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    122 		daddr4->sin_port = htons(PORT);
    123 		break;
    124 	case AF_INET6:
    125 		saddr6 = (struct sockaddr_in6 *)&saddr;
    126 		saddr6->sin6_family = AF_INET6;
    127 		saddr6->sin6_addr = in6addr_any;
    128 		saddr6->sin6_port = 0;
    129 
    130 		daddr6 = (struct sockaddr_in6 *)&daddr;
    131 		daddr6->sin6_family = AF_INET6;
    132 		daddr6->sin6_addr = in6addr_loopback;
    133 		daddr6->sin6_port = htons(PORT);
    134 		break;
    135 	default:
    136 		error(1, 0, "Unsupported family %d", family);
    137 	}
    138 
    139 	if (numa_run_on_node(node_id) < 0)
    140 		error(1, errno, "failed to pin to node");
    141 
    142 	fd = socket(family, proto, 0);
    143 	if (fd < 0)
    144 		error(1, errno, "failed to create send socket");
    145 
    146 	if (bind(fd, (struct sockaddr *)&saddr, sizeof(saddr)))
    147 		error(1, errno, "failed to bind send socket");
    148 
    149 	if (connect(fd, (struct sockaddr *)&daddr, sizeof(daddr)))
    150 		error(1, errno, "failed to connect send socket");
    151 
    152 	if (send(fd, "a", 1, 0) < 0)
    153 		error(1, errno, "failed to send message");
    154 
    155 	close(fd);
    156 }
    157 
    158 static
    159 void receive_on_node(int *rcv_fd, int len, int epfd, int node_id, int proto)
    160 {
    161 	struct epoll_event ev;
    162 	int i, fd;
    163 	char buf[8];
    164 
    165 	i = epoll_wait(epfd, &ev, 1, -1);
    166 	if (i < 0)
    167 		error(1, errno, "epoll_wait failed");
    168 
    169 	if (proto == SOCK_STREAM) {
    170 		fd = accept(ev.data.fd, NULL, NULL);
    171 		if (fd < 0)
    172 			error(1, errno, "failed to accept");
    173 		i = recv(fd, buf, sizeof(buf), 0);
    174 		close(fd);
    175 	} else {
    176 		i = recv(ev.data.fd, buf, sizeof(buf), 0);
    177 	}
    178 
    179 	if (i < 0)
    180 		error(1, errno, "failed to recv");
    181 
    182 	for (i = 0; i < len; ++i)
    183 		if (ev.data.fd == rcv_fd[i])
    184 			break;
    185 	if (i == len)
    186 		error(1, 0, "failed to find socket");
    187 	fprintf(stderr, "send node %d, receive socket %d\n", node_id, i);
    188 	if (node_id != i)
    189 		error(1, 0, "node id/receive socket mismatch");
    190 }
    191 
    192 static void test(int *rcv_fd, int len, int family, int proto)
    193 {
    194 	struct epoll_event ev;
    195 	int epfd, node;
    196 
    197 	build_rcv_group(rcv_fd, len, family, proto);
    198 	attach_bpf(rcv_fd[0]);
    199 
    200 	epfd = epoll_create(1);
    201 	if (epfd < 0)
    202 		error(1, errno, "failed to create epoll");
    203 	for (node = 0; node < len; ++node) {
    204 		ev.events = EPOLLIN;
    205 		ev.data.fd = rcv_fd[node];
    206 		if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fd[node], &ev))
    207 			error(1, errno, "failed to register sock epoll");
    208 	}
    209 
    210 	/* Forward iterate */
    211 	for (node = 0; node < len; ++node) {
    212 		send_from_node(node, family, proto);
    213 		receive_on_node(rcv_fd, len, epfd, node, proto);
    214 	}
    215 
    216 	/* Reverse iterate */
    217 	for (node = len - 1; node >= 0; --node) {
    218 		send_from_node(node, family, proto);
    219 		receive_on_node(rcv_fd, len, epfd, node, proto);
    220 	}
    221 
    222 	close(epfd);
    223 	for (node = 0; node < len; ++node)
    224 		close(rcv_fd[node]);
    225 }
    226 
    227 int main(void)
    228 {
    229 	int *rcv_fd, nodes;
    230 
    231 	if (numa_available() < 0)
    232 		error(1, errno, "no numa api support");
    233 
    234 	nodes = numa_max_node() + 1;
    235 
    236 	rcv_fd = calloc(nodes, sizeof(int));
    237 	if (!rcv_fd)
    238 		error(1, 0, "failed to allocate array");
    239 
    240 	fprintf(stderr, "---- IPv4 UDP ----\n");
    241 	test(rcv_fd, nodes, AF_INET, SOCK_DGRAM);
    242 
    243 	fprintf(stderr, "---- IPv6 UDP ----\n");
    244 	test(rcv_fd, nodes, AF_INET6, SOCK_DGRAM);
    245 
    246 	fprintf(stderr, "---- IPv4 TCP ----\n");
    247 	test(rcv_fd, nodes, AF_INET, SOCK_STREAM);
    248 
    249 	fprintf(stderr, "---- IPv6 TCP ----\n");
    250 	test(rcv_fd, nodes, AF_INET6, SOCK_STREAM);
    251 
    252 	free(rcv_fd);
    253 
    254 	fprintf(stderr, "SUCCESS\n");
    255 	return 0;
    256 }
    257