Home | History | Annotate | Download | only in server
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <errno.h>
     18 #include <netdb.h>
     19 #include <string.h>
     20 #include <netinet/in.h>
     21 #include <netinet/tcp.h>
     22 #include <sys/socket.h>
     23 #include <sys/uio.h>
     24 
     25 #include <linux/netlink.h>
     26 #include <linux/sock_diag.h>
     27 #include <linux/inet_diag.h>
     28 
     29 #define LOG_TAG "Netd"
     30 
     31 #include <android-base/strings.h>
     32 #include <cutils/log.h>
     33 
     34 #include "NetdConstants.h"
     35 #include "Permission.h"
     36 #include "SockDiag.h"
     37 #include "Stopwatch.h"
     38 
     39 #include <chrono>
     40 
     41 #ifndef SOCK_DESTROY
     42 #define SOCK_DESTROY 21
     43 #endif
     44 
     45 #define INET_DIAG_BC_MARK_COND 10
     46 
     47 namespace android {
     48 namespace net {
     49 
     50 namespace {
     51 
     // Checks, without blocking, whether netlink socket `fd` has an NLMSG_ERROR message queued.
//
// The receive queue is peeked with MSG_DONTWAIT | MSG_PEEK:
//  - nothing queued (recv fails with EAGAIN): returns 0;
//  - recv fails for any other reason: returns -errno;
//  - the next message is a complete nlmsghdr + nlmsgerr of type NLMSG_ERROR: that message is
//    consumed and its err.error field is returned;
//  - anything else: returns 0 and leaves the message queued for the caller to read.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // Read failed (error), or nothing to read (good).
        return errno == EAGAIN ? 0 : -errno;
    }

    const bool isErrorMessage =
            peeked == static_cast<ssize_t>(sizeof(reply)) && reply.h.nlmsg_type == NLMSG_ERROR;
    if (!isErrorMessage) {
        // The kernel replied with something else; leave it for the caller.
        return 0;
    }

    // Dequeue the error message that was only peeked at above.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}
     70 
     71 }  // namespace
     72 
     73 bool SockDiag::open() {
     74     if (hasSocks()) {
     75         return false;
     76     }
     77 
     78     mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
     79     mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
     80     if (!hasSocks()) {
     81         closeSocks();
     82         return false;
     83     }
     84 
     85     sockaddr_nl nl = { .nl_family = AF_NETLINK };
     86     if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
     87         (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
     88         closeSocks();
     89         return false;
     90     }
     91 
     92     return true;
     93 }
     94 
     95 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
     96                               iovec *iov, int iovcnt) {
     97     struct {
     98         nlmsghdr nlh;
     99         inet_diag_req_v2 req;
    100     } __attribute__((__packed__)) request = {
    101         .nlh = {
    102             .nlmsg_type = SOCK_DIAG_BY_FAMILY,
    103             .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
    104         },
    105         .req = {
    106             .sdiag_family = family,
    107             .sdiag_protocol = proto,
    108             .idiag_ext = extensions,
    109             .idiag_states = states,
    110         },
    111     };
    112 
    113     size_t len = 0;
    114     iov[0].iov_base = &request;
    115     iov[0].iov_len = sizeof(request);
    116     for (int i = 0; i < iovcnt; i++) {
    117         len += iov[i].iov_len;
    118     }
    119     request.nlh.nlmsg_len = len;
    120 
    121     if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
    122         return -errno;
    123     }
    124 
    125     return checkError(mSock);
    126 }
    127 
    128 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
    129     iovec iov[] = {
    130         { nullptr, 0 },
    131     };
    132     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
    133 }
    134 
    135 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
    136     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
    137     addrinfo *res;
    138     in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
    139     int ret;
    140 
    141     // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
    142     // doing string conversions when they're not necessary.
    143     if ((ret = getaddrinfo(addrstr, nullptr, &hints, &res)) != 0) {
    144         return -EINVAL;
    145     }
    146 
    147     // So we don't have to call freeaddrinfo on every failure path.
    148     ScopedAddrinfo resP(res);
    149 
    150     void *addr;
    151     uint8_t addrlen;
    152     if (res->ai_family == AF_INET && family == AF_INET) {
    153         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
    154         addr = &ina;
    155         addrlen = sizeof(ina);
    156     } else if (res->ai_family == AF_INET && family == AF_INET6) {
    157         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
    158         mapped.s6_addr32[3] = ina.s_addr;
    159         addr = &mapped;
    160         addrlen = sizeof(mapped);
    161     } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
    162         in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
    163         addr = &in6a;
    164         addrlen = sizeof(in6a);
    165     } else {
    166         return -EAFNOSUPPORT;
    167     }
    168 
    169     uint8_t prefixlen = addrlen * 8;
    170     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
    171     uint8_t nojump = yesjump + 4;
    172 
    173     struct {
    174         nlattr nla;
    175         inet_diag_bc_op op;
    176         inet_diag_hostcond cond;
    177     } __attribute__((__packed__)) attrs = {
    178         .nla = {
    179             .nla_type = INET_DIAG_REQ_BYTECODE,
    180         },
    181         .op = {
    182             INET_DIAG_BC_S_COND,
    183             yesjump,
    184             nojump,
    185         },
    186         .cond = {
    187             family,
    188             prefixlen,
    189             -1,
    190             {}
    191         },
    192     };
    193 
    194     attrs.nla.nla_len = sizeof(attrs) + addrlen;
    195 
    196     iovec iov[] = {
    197         { nullptr,           0 },
    198         { &attrs,            sizeof(attrs) },
    199         { addr,              addrlen },
    200     };
    201 
    202     uint32_t states = ~(1 << TCP_TIME_WAIT);
    203     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
    204 }
    205 
    206 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
    207     NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
    208         const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
    209         if (shouldDestroy(proto, msg)) {
    210             sockDestroy(proto, msg);
    211         }
    212     };
    213 
    214     return processNetlinkDump(mSock, callback);
    215 }
    216 
    217 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
    218     NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
    219         if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
    220             ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
    221             return;
    222         }
    223         Fwmark mark;
    224         struct tcp_info *tcpinfo = nullptr;
    225         uint32_t tcpinfoLength = 0;
    226         inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
    227         uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
    228         struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
    229         while (RTA_OK(attr, attr_len)) {
    230             if (attr->rta_type == INET_DIAG_INFO) {
    231                 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
    232                 tcpinfoLength = RTA_PAYLOAD(attr);
    233             }
    234             if (attr->rta_type == INET_DIAG_MARK) {
    235                 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
    236             }
    237             attr = RTA_NEXT(attr, attr_len);
    238         }
    239 
    240         tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
    241     };
    242 
    243     return processNetlinkDump(mSock, callback);
    244 }
    245 
    246 // Determines whether a socket is a loopback socket. Does not check socket state.
    247 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
    248     switch (msg->idiag_family) {
    249         case AF_INET:
    250             // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
    251             return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
    252                    IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
    253                    msg->id.idiag_src[0] == msg->id.idiag_dst[0];
    254 
    255         case AF_INET6: {
    256             const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
    257             const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
    258             return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
    259                    (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
    260                    IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
    261                    !memcmp(src, dst, sizeof(*src));
    262         }
    263         default:
    264             return false;
    265     }
    266 }
    267 
    268 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
    269     if (msg == nullptr) {
    270        return 0;
    271     }
    272 
    273     DestroyRequest request = {
    274         .nlh = {
    275             .nlmsg_type = SOCK_DESTROY,
    276             .nlmsg_flags = NLM_F_REQUEST,
    277         },
    278         .req = {
    279             .sdiag_family = msg->idiag_family,
    280             .sdiag_protocol = proto,
    281             .idiag_states = (uint32_t) (1 << msg->idiag_state),
    282             .id = msg->id,
    283         },
    284     };
    285     request.nlh.nlmsg_len = sizeof(request);
    286 
    287     if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
    288         return -errno;
    289     }
    290 
    291     int ret = checkError(mWriteSock);
    292     if (!ret) mSocketsDestroyed++;
    293     return ret;
    294 }
    295 
    296 int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
    297     if (!hasSocks()) {
    298         return -EBADFD;
    299     }
    300 
    301     if (int ret = sendDumpRequest(proto, family, addrstr)) {
    302         return ret;
    303     }
    304 
    305     auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
    306 
    307     return readDiagMsg(proto, destroyAll);
    308 }
    309 
    310 int SockDiag::destroySockets(const char *addrstr) {
    311     Stopwatch s;
    312     mSocketsDestroyed = 0;
    313 
    314     if (!strchr(addrstr, ':')) {
    315         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
    316             ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
    317             return ret;
    318         }
    319     }
    320     if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
    321         ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
    322         return ret;
    323     }
    324 
    325     if (mSocketsDestroyed > 0) {
    326         ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, s.timeTaken());
    327     }
    328 
    329     return mSocketsDestroyed;
    330 }
    331 
    332 int SockDiag::destroyLiveSockets(DestroyFilter destroyFilter, const char *what,
    333                                  iovec *iov, int iovcnt) {
    334     const int proto = IPPROTO_TCP;
    335     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
    336 
    337     for (const int family : {AF_INET, AF_INET6}) {
    338         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
    339         if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
    340             ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
    341             return ret;
    342         }
    343         if (int ret = readDiagMsg(proto, destroyFilter)) {
    344             ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
    345             return ret;
    346         }
    347     }
    348 
    349     return 0;
    350 }
    351 
    352 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
    353     const int proto = IPPROTO_TCP;
    354     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
    355     const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
    356 
    357     iovec iov[] = {
    358         { nullptr, 0 },
    359     };
    360 
    361     for (const int family : {AF_INET, AF_INET6}) {
    362         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
    363         if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
    364             ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
    365             return ret;
    366         }
    367         if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
    368             ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
    369             return ret;
    370         }
    371     }
    372 
    373     return 0;
    374 }
    375 
    376 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
    377     mSocketsDestroyed = 0;
    378     Stopwatch s;
    379 
    380     auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
    381         return msg != nullptr &&
    382                msg->idiag_uid == uid &&
    383                !(excludeLoopback && isLoopbackSocket(msg));
    384     };
    385 
    386     for (const int family : {AF_INET, AF_INET6}) {
    387         const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
    388         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
    389         if (int ret = sendDumpRequest(proto, family, states)) {
    390             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
    391             return ret;
    392         }
    393         if (int ret = readDiagMsg(proto, shouldDestroy)) {
    394             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
    395             return ret;
    396         }
    397     }
    398 
    399     if (mSocketsDestroyed > 0) {
    400         ALOGI("Destroyed %d sockets for UID in %.1f ms", mSocketsDestroyed, s.timeTaken());
    401     }
    402 
    403     return 0;
    404 }
    405 
    406 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
    407                              bool excludeLoopback) {
    408     mSocketsDestroyed = 0;
    409     Stopwatch s;
    410 
    411     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
    412         return msg != nullptr &&
    413                uidRanges.hasUid(msg->idiag_uid) &&
    414                skipUids.find(msg->idiag_uid) == skipUids.end() &&
    415                !(excludeLoopback && isLoopbackSocket(msg));
    416     };
    417 
    418     iovec iov[] = {
    419         { nullptr, 0 },
    420     };
    421 
    422     if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
    423         return ret;
    424     }
    425 
    426     std::vector<uid_t> skipUidStrings;
    427     for (uid_t uid : skipUids) {
    428         skipUidStrings.push_back(uid);
    429     }
    430     std::sort(skipUidStrings.begin(), skipUidStrings.end());
    431 
    432     if (mSocketsDestroyed > 0) {
    433         ALOGI("Destroyed %d sockets for %s skip={%s} in %.1f ms",
    434               mSocketsDestroyed, uidRanges.toString().c_str(),
    435               android::base::Join(skipUidStrings, " ").c_str(), s.timeTaken());
    436     }
    437 
    438     return 0;
    439 }
    440 
    441 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
    442 // 1. The opening app no longer has permission to use this network, or:
    443 // 2. The opening app does have permission, but did not explicitly select this network.
    444 //
    445 // We destroy sockets without the explicit bit because we want to avoid the situation where a
    446 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
    447 // might have opened a socket on this network just because it was the default network at the
    448 // time. If we don't kill these sockets, those apps could continue to use them without realizing
    449 // that they are now sending and receiving traffic on a network that is now restricted.
    450 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
    451                                               bool excludeLoopback) {
    452     struct markmatch {
    453         inet_diag_bc_op op;
    454         // TODO: switch to inet_diag_markcond
    455         __u32 mark;
    456         __u32 mask;
    457     } __attribute__((packed));
    458     constexpr uint8_t matchlen = sizeof(markmatch);
    459 
    460     Fwmark netIdMark, netIdMask;
    461     netIdMark.netId = netId;
    462     netIdMask.netId = 0xffff;
    463 
    464     Fwmark controlMark;
    465     controlMark.explicitlySelected = true;
    466     controlMark.permission = permission;
    467 
    468     // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
    469     struct bytecode {
    470         markmatch netIdMatch;
    471         markmatch controlMatch;
    472         inet_diag_bc_op controlJump;
    473     } __attribute__((packed)) bytecode;
    474 
    475     // The length of the INET_DIAG_BC_JMP instruction.
    476     constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
    477     // Jump exactly this far past the end of the program to reject.
    478     constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
    479     // Total length of the program.
    480     constexpr uint8_t bytecodelen = sizeof(bytecode);
    481 
    482     bytecode = (struct bytecode) {
    483         // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
    484         { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
    485           netIdMark.intValue, netIdMask.intValue },
    486 
    487         // If explicit and permission bits match, go to the JMP below which rejects the socket
    488         // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
    489         // socket (so we destroy it).
    490         { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
    491           controlMark.intValue, controlMark.intValue },
    492 
    493         // This JMP unconditionally rejects the packet by jumping to the reject target. It is
    494         // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
    495         // is invalid because the target of every no jump must always be reachable by yes jumps.
    496         // Without this JMP, the accept target is not reachable by yes jumps and the program will
    497         // be rejected by the validator.
    498         { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
    499 
    500         // We have reached the end of the program. Accept the socket, and destroy it below.
    501     };
    502 
    503     struct nlattr nla = {
    504         .nla_type = INET_DIAG_REQ_BYTECODE,
    505         .nla_len = sizeof(struct nlattr) + bytecodelen,
    506     };
    507 
    508     iovec iov[] = {
    509         { nullptr,   0 },
    510         { &nla,      sizeof(nla) },
    511         { &bytecode, bytecodelen },
    512     };
    513 
    514     mSocketsDestroyed = 0;
    515     Stopwatch s;
    516 
    517     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
    518         return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
    519     };
    520 
    521     if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
    522         return ret;
    523     }
    524 
    525     if (mSocketsDestroyed > 0) {
    526         ALOGI("Destroyed %d sockets for netId %d permission=%d in %.1f ms",
    527               mSocketsDestroyed, netId, permission, s.timeTaken());
    528     }
    529 
    530     return 0;
    531 }
    532 
    533 }  // namespace net
    534 }  // namespace android
    535