Home | History | Annotate | Download | only in client
      1 /*
      2  * Copyright (C) 2014 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "NetdClient.h"

#include <arpa/inet.h>
#include <errno.h>
#include <math.h>
#include <stddef.h>
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>

#include "Fwmark.h"
#include "FwmarkClient.h"
#include "FwmarkCommand.h"
#include "resolv_netid.h"
#include "Stopwatch.h"
     32 
     33 namespace {
     34 
     35 std::atomic_uint netIdForProcess(NETID_UNSET);
     36 std::atomic_uint netIdForResolv(NETID_UNSET);
     37 
     38 typedef int (*Accept4FunctionType)(int, sockaddr*, socklen_t*, int);
     39 typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
     40 typedef int (*SocketFunctionType)(int, int, int);
     41 typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
     42 
     43 // These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
     44 // it's okay that they are read later at runtime without a lock.
     45 Accept4FunctionType libcAccept4 = 0;
     46 ConnectFunctionType libcConnect = 0;
     47 SocketFunctionType libcSocket = 0;
     48 
     49 int closeFdAndSetErrno(int fd, int error) {
     50     close(fd);
     51     errno = -error;
     52     return -1;
     53 }
     54 
     55 int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
     56     int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
     57     if (acceptedSocket == -1) {
     58         return -1;
     59     }
     60     int family;
     61     if (addr) {
     62         family = addr->sa_family;
     63     } else {
     64         socklen_t familyLen = sizeof(family);
     65         if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
     66             return closeFdAndSetErrno(acceptedSocket, -errno);
     67         }
     68     }
     69     if (FwmarkClient::shouldSetFwmark(family)) {
     70         FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0, 0};
     71         if (int error = FwmarkClient().send(&command, acceptedSocket, nullptr)) {
     72             return closeFdAndSetErrno(acceptedSocket, error);
     73         }
     74     }
     75     return acceptedSocket;
     76 }
     77 
     78 int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
     79     const bool shouldSetFwmark = (sockfd >= 0) && addr
     80             && FwmarkClient::shouldSetFwmark(addr->sa_family);
     81     if (shouldSetFwmark) {
     82         FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0, 0};
     83         if (int error = FwmarkClient().send(&command, sockfd, nullptr)) {
     84             errno = -error;
     85             return -1;
     86         }
     87     }
     88     // Latency measurement does not include time of sending commands to Fwmark
     89     Stopwatch s;
     90     const int ret = libcConnect(sockfd, addr, addrlen);
     91     // Save errno so it isn't clobbered by sending ON_CONNECT_COMPLETE
     92     const int connectErrno = errno;
     93     const unsigned latencyMs = lround(s.timeTaken());
     94     // Send an ON_CONNECT_COMPLETE command that includes sockaddr and connect latency for reporting
     95     if (shouldSetFwmark && FwmarkClient::shouldReportConnectComplete(addr->sa_family)) {
     96         FwmarkConnectInfo connectInfo(ret == 0 ? 0 : connectErrno, latencyMs, addr);
     97         // TODO: get the netId from the socket mark once we have continuous benchmark runs
     98         FwmarkCommand command = {FwmarkCommand::ON_CONNECT_COMPLETE, /* netId (ignored) */ 0,
     99                 /* uid (filled in by the server) */ 0};
    100         // Ignore return value since it's only used for logging
    101         FwmarkClient().send(&command, sockfd, &connectInfo);
    102     }
    103     errno = connectErrno;
    104     return ret;
    105 }
    106 
    107 int netdClientSocket(int domain, int type, int protocol) {
    108     int socketFd = libcSocket(domain, type, protocol);
    109     if (socketFd == -1) {
    110         return -1;
    111     }
    112     unsigned netId = netIdForProcess;
    113     if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
    114         if (int error = setNetworkForSocket(netId, socketFd)) {
    115             return closeFdAndSetErrno(socketFd, error);
    116         }
    117     }
    118     return socketFd;
    119 }
    120 
    121 unsigned getNetworkForResolv(unsigned netId) {
    122     if (netId != NETID_UNSET) {
    123         return netId;
    124     }
    125     netId = netIdForProcess;
    126     if (netId != NETID_UNSET) {
    127         return netId;
    128     }
    129     return netIdForResolv;
    130 }
    131 
    132 int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
    133     if (netId == NETID_UNSET) {
    134         *target = netId;
    135         return 0;
    136     }
    137     // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
    138     // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
    139     // might itself cause another check with the fwmark server, which would be wasteful.
    140     int socketFd;
    141     if (libcSocket) {
    142         socketFd = libcSocket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
    143     } else {
    144         socketFd = socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
    145     }
    146     if (socketFd < 0) {
    147         return -errno;
    148     }
    149     int error = setNetworkForSocket(netId, socketFd);
    150     if (!error) {
    151         *target = netId;
    152     }
    153     close(socketFd);
    154     return error;
    155 }
    156 
    157 }  // namespace
    158 
    159 // accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
    160 extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
    161     if (function && *function) {
    162         libcAccept4 = *function;
    163         *function = netdClientAccept4;
    164     }
    165 }
    166 
    167 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
    168     if (function && *function) {
    169         libcConnect = *function;
    170         *function = netdClientConnect;
    171     }
    172 }
    173 
    174 extern "C" void netdClientInitSocket(SocketFunctionType* function) {
    175     if (function && *function) {
    176         libcSocket = *function;
    177         *function = netdClientSocket;
    178     }
    179 }
    180 
    181 extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
    182     if (function) {
    183         *function = getNetworkForResolv;
    184     }
    185 }
    186 
    187 extern "C" int getNetworkForSocket(unsigned* netId, int socketFd) {
    188     if (!netId || socketFd < 0) {
    189         return -EBADF;
    190     }
    191     Fwmark fwmark;
    192     socklen_t fwmarkLen = sizeof(fwmark.intValue);
    193     if (getsockopt(socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
    194         return -errno;
    195     }
    196     *netId = fwmark.netId;
    197     return 0;
    198 }
    199 
    200 extern "C" unsigned getNetworkForProcess() {
    201     return netIdForProcess;
    202 }
    203 
    204 extern "C" int setNetworkForSocket(unsigned netId, int socketFd) {
    205     if (socketFd < 0) {
    206         return -EBADF;
    207     }
    208     FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId, 0};
    209     return FwmarkClient().send(&command, socketFd, nullptr);
    210 }
    211 
    212 extern "C" int setNetworkForProcess(unsigned netId) {
    213     return setNetworkForTarget(netId, &netIdForProcess);
    214 }
    215 
    216 extern "C" int setNetworkForResolv(unsigned netId) {
    217     return setNetworkForTarget(netId, &netIdForResolv);
    218 }
    219 
    220 extern "C" int protectFromVpn(int socketFd) {
    221     if (socketFd < 0) {
    222         return -EBADF;
    223     }
    224     FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0, 0};
    225     return FwmarkClient().send(&command, socketFd, nullptr);
    226 }
    227 
    228 extern "C" int setNetworkForUser(uid_t uid, int socketFd) {
    229     if (socketFd < 0) {
    230         return -EBADF;
    231     }
    232     FwmarkCommand command = {FwmarkCommand::SELECT_FOR_USER, 0, uid};
    233     return FwmarkClient().send(&command, socketFd, nullptr);
    234 }
    235 
    236 extern "C" int queryUserAccess(uid_t uid, unsigned netId) {
    237     FwmarkCommand command = {FwmarkCommand::QUERY_USER_ACCESS, netId, uid};
    238     return FwmarkClient().send(&command, -1, nullptr);
    239 }
    240