Home | History | Annotate | Download | only in client
      1 /*
      2  * Copyright (C) 2014 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "NetdClient.h"

#include "Fwmark.h"
#include "FwmarkClient.h"
#include "FwmarkCommand.h"
#include "resolv_netid.h"

#include <atomic>
#include <errno.h>
#include <sys/socket.h>
#include <unistd.h>
     27 
     28 namespace {
     29 
     30 std::atomic_uint netIdForProcess(NETID_UNSET);
     31 std::atomic_uint netIdForResolv(NETID_UNSET);
     32 
     33 typedef int (*Accept4FunctionType)(int, sockaddr*, socklen_t*, int);
     34 typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
     35 typedef int (*SocketFunctionType)(int, int, int);
     36 typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
     37 
     38 // These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
     39 // it's okay that they are read later at runtime without a lock.
     40 Accept4FunctionType libcAccept4 = 0;
     41 ConnectFunctionType libcConnect = 0;
     42 SocketFunctionType libcSocket = 0;
     43 
     44 int closeFdAndSetErrno(int fd, int error) {
     45     close(fd);
     46     errno = -error;
     47     return -1;
     48 }
     49 
     50 int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
     51     int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
     52     if (acceptedSocket == -1) {
     53         return -1;
     54     }
     55     int family;
     56     if (addr) {
     57         family = addr->sa_family;
     58     } else {
     59         socklen_t familyLen = sizeof(family);
     60         if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
     61             return closeFdAndSetErrno(acceptedSocket, -errno);
     62         }
     63     }
     64     if (FwmarkClient::shouldSetFwmark(family)) {
     65         FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0, 0};
     66         if (int error = FwmarkClient().send(&command, sizeof(command), acceptedSocket)) {
     67             return closeFdAndSetErrno(acceptedSocket, error);
     68         }
     69     }
     70     return acceptedSocket;
     71 }
     72 
     73 int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
     74     if (sockfd >= 0 && addr && FwmarkClient::shouldSetFwmark(addr->sa_family)) {
     75         FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0, 0};
     76         if (int error = FwmarkClient().send(&command, sizeof(command), sockfd)) {
     77             errno = -error;
     78             return -1;
     79         }
     80     }
     81     return libcConnect(sockfd, addr, addrlen);
     82 }
     83 
     84 int netdClientSocket(int domain, int type, int protocol) {
     85     int socketFd = libcSocket(domain, type, protocol);
     86     if (socketFd == -1) {
     87         return -1;
     88     }
     89     unsigned netId = netIdForProcess;
     90     if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
     91         if (int error = setNetworkForSocket(netId, socketFd)) {
     92             return closeFdAndSetErrno(socketFd, error);
     93         }
     94     }
     95     return socketFd;
     96 }
     97 
     98 unsigned getNetworkForResolv(unsigned netId) {
     99     if (netId != NETID_UNSET) {
    100         return netId;
    101     }
    102     netId = netIdForProcess;
    103     if (netId != NETID_UNSET) {
    104         return netId;
    105     }
    106     return netIdForResolv;
    107 }
    108 
    109 int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
    110     if (netId == NETID_UNSET) {
    111         *target = netId;
    112         return 0;
    113     }
    114     // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
    115     // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
    116     // might itself cause another check with the fwmark server, which would be wasteful.
    117     int socketFd;
    118     if (libcSocket) {
    119         socketFd = libcSocket(AF_INET6, SOCK_DGRAM, 0);
    120     } else {
    121         socketFd = socket(AF_INET6, SOCK_DGRAM, 0);
    122     }
    123     if (socketFd < 0) {
    124         return -errno;
    125     }
    126     int error = setNetworkForSocket(netId, socketFd);
    127     if (!error) {
    128         *target = netId;
    129     }
    130     close(socketFd);
    131     return error;
    132 }
    133 
    134 }  // namespace
    135 
    136 // accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
    137 extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
    138     if (function && *function) {
    139         libcAccept4 = *function;
    140         *function = netdClientAccept4;
    141     }
    142 }
    143 
    144 extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
    145     if (function && *function) {
    146         libcConnect = *function;
    147         *function = netdClientConnect;
    148     }
    149 }
    150 
    151 extern "C" void netdClientInitSocket(SocketFunctionType* function) {
    152     if (function && *function) {
    153         libcSocket = *function;
    154         *function = netdClientSocket;
    155     }
    156 }
    157 
    158 extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
    159     if (function) {
    160         *function = getNetworkForResolv;
    161     }
    162 }
    163 
    164 extern "C" int getNetworkForSocket(unsigned* netId, int socketFd) {
    165     if (!netId || socketFd < 0) {
    166         return -EBADF;
    167     }
    168     Fwmark fwmark;
    169     socklen_t fwmarkLen = sizeof(fwmark.intValue);
    170     if (getsockopt(socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
    171         return -errno;
    172     }
    173     *netId = fwmark.netId;
    174     return 0;
    175 }
    176 
    177 extern "C" unsigned getNetworkForProcess() {
    178     return netIdForProcess;
    179 }
    180 
    181 extern "C" int setNetworkForSocket(unsigned netId, int socketFd) {
    182     if (socketFd < 0) {
    183         return -EBADF;
    184     }
    185     FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId, 0};
    186     return FwmarkClient().send(&command, sizeof(command), socketFd);
    187 }
    188 
    189 extern "C" int setNetworkForProcess(unsigned netId) {
    190     return setNetworkForTarget(netId, &netIdForProcess);
    191 }
    192 
    193 extern "C" int setNetworkForResolv(unsigned netId) {
    194     return setNetworkForTarget(netId, &netIdForResolv);
    195 }
    196 
    197 extern "C" int protectFromVpn(int socketFd) {
    198     if (socketFd < 0) {
    199         return -EBADF;
    200     }
    201     FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0, 0};
    202     return FwmarkClient().send(&command, sizeof(command), socketFd);
    203 }
    204 
    205 extern "C" int setNetworkForUser(uid_t uid, int socketFd) {
    206     if (socketFd < 0) {
    207         return -EBADF;
    208     }
    209     FwmarkCommand command = {FwmarkCommand::SELECT_FOR_USER, 0, uid};
    210     return FwmarkClient().send(&command, sizeof(command), socketFd);
    211 }
    212