Home | History | Annotate | Download | only in dns
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "dns/DnsTlsTransport.h"
     18 
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <fcntl.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <poll.h>
#include <stdlib.h>
#include <unistd.h>
     25 
     26 #define LOG_TAG "DnsTlsTransport"
     27 #define DBG 0
     28 
     29 #include "log/log.h"
     30 #include "Fwmark.h"
     31 #undef ADD  // already defined in nameser.h
     32 #include "NetdConstants.h"
     33 #include "Permission.h"
     34 
     35 
     36 namespace android {
     37 namespace net {
     38 
     39 namespace {
     40 
     41 bool setNonBlocking(int fd, bool enabled) {
     42     int flags = fcntl(fd, F_GETFL);
     43     if (flags < 0) return false;
     44 
     45     if (enabled) {
     46         flags |= O_NONBLOCK;
     47     } else {
     48         flags &= ~O_NONBLOCK;
     49     }
     50     return (fcntl(fd, F_SETFL, flags) == 0);
     51 }
     52 
     53 int waitForReading(int fd) {
     54     fd_set fds;
     55     FD_ZERO(&fds);
     56     FD_SET(fd, &fds);
     57     const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
     58     if (DBG && ret <= 0) {
     59         ALOGD("select");
     60     }
     61     return ret;
     62 }
     63 
     64 int waitForWriting(int fd) {
     65     fd_set fds;
     66     FD_ZERO(&fds);
     67     FD_SET(fd, &fds);
     68     const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
     69     if (DBG && ret <= 0) {
     70         ALOGD("select");
     71     }
     72     return ret;
     73 }
     74 
     75 }  // namespace
     76 
     77 android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
     78     android::base::unique_fd fd;
     79     int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
     80     switch (mProtocol) {
     81         case IPPROTO_TCP:
     82             type |= SOCK_STREAM;
     83             break;
     84         default:
     85             errno = EPROTONOSUPPORT;
     86             return fd;
     87     }
     88 
     89     fd.reset(socket(mAddr.ss_family, type, mProtocol));
     90     if (fd.get() == -1) {
     91         return fd;
     92     }
     93 
     94     const socklen_t len = sizeof(mMark);
     95     if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
     96         fd.reset();
     97     } else if (connect(fd.get(),
     98             reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
     99         && errno != EINPROGRESS) {
    100         fd.reset();
    101     }
    102 
    103     return fd;
    104 }
    105 
    106 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
    107     int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
    108     unsigned char spki[spki_len];
    109     unsigned char* temp = spki;
    110     if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
    111         ALOGW("SPKI length mismatch");
    112         return false;
    113     }
    114     out->resize(SHA256_SIZE);
    115     unsigned int digest_len = 0;
    116     int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
    117     if (ret != 1) {
    118         ALOGW("Server cert digest extraction failed");
    119         return false;
    120     }
    121     if (digest_len != out->size()) {
    122         ALOGW("Wrong digest length: %d", digest_len);
    123         return false;
    124     }
    125     return true;
    126 }
    127 
    128 SSL* DnsTlsTransport::sslConnect(int fd) {
    129     if (fd < 0) {
    130         ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
    131         return nullptr;
    132     }
    133 
    134     // Set up TLS context.
    135     bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
    136     if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
    137         !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
    138         ALOGD("failed to min/max TLS versions");
    139         return nullptr;
    140     }
    141 
    142     bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
    143     bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
    144     SSL_set_bio(ssl.get(), bio.get(), bio.get());
    145     bio.release();
    146 
    147     if (!setNonBlocking(fd, false)) {
    148         ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
    149         return nullptr;
    150     }
    151 
    152     for (;;) {
    153         if (DBG) {
    154             ALOGD("%u Calling SSL_connect", mMark);
    155         }
    156         int ret = SSL_connect(ssl.get());
    157         if (DBG) {
    158             ALOGD("%u SSL_connect returned %d", mMark, ret);
    159         }
    160         if (ret == 1) break;  // SSL handshake complete;
    161 
    162         const int ssl_err = SSL_get_error(ssl.get(), ret);
    163         switch (ssl_err) {
    164             case SSL_ERROR_WANT_READ:
    165                 if (waitForReading(fd) != 1) {
    166                     ALOGW("SSL_connect read error");
    167                     return nullptr;
    168                 }
    169                 break;
    170             case SSL_ERROR_WANT_WRITE:
    171                 if (waitForWriting(fd) != 1) {
    172                     ALOGW("SSL_connect write error");
    173                     return nullptr;
    174                 }
    175                 break;
    176             default:
    177                 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
    178                 return nullptr;
    179         }
    180     }
    181 
    182     if (!mFingerprints.empty()) {
    183         if (DBG) {
    184             ALOGD("Checking DNS over TLS fingerprint");
    185         }
    186         // TODO: Follow the cert chain and check all the way up.
    187         bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get()));
    188         if (!cert) {
    189             ALOGW("Server has null certificate");
    190             return nullptr;
    191         }
    192         std::vector<uint8_t> digest;
    193         if (!getSPKIDigest(cert.get(), &digest)) {
    194             ALOGE("Digest computation failed");
    195             return nullptr;
    196         }
    197 
    198         if (mFingerprints.count(digest) == 0) {
    199             ALOGW("No matching fingerprint");
    200             return nullptr;
    201         }
    202         if (DBG) {
    203             ALOGD("DNS over TLS fingerprint is correct");
    204         }
    205     }
    206 
    207     if (DBG) {
    208         ALOGD("%u handshake complete", mMark);
    209     }
    210     return ssl.release();
    211 }
    212 
    213 bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
    214     if (DBG) {
    215         ALOGD("%u Writing %d bytes", mMark, len);
    216     }
    217     for (;;) {
    218         int ret = SSL_write(ssl, buffer, len);
    219         if (ret == len) break;  // SSL write complete;
    220 
    221         if (ret < 1) {
    222             const int ssl_err = SSL_get_error(ssl, ret);
    223             switch (ssl_err) {
    224                 case SSL_ERROR_WANT_WRITE:
    225                     if (waitForWriting(fd) != 1) {
    226                         if (DBG) {
    227                             ALOGW("SSL_write error");
    228                         }
    229                         return false;
    230                     }
    231                     continue;
    232                 case 0:
    233                     break;  // SSL write complete;
    234                 default:
    235                     if (DBG) {
    236                         ALOGW("SSL_write error %d", ssl_err);
    237                     }
    238                     return false;
    239             }
    240         }
    241     }
    242     if (DBG) {
    243         ALOGD("%u Wrote %d bytes", mMark, len);
    244     }
    245     return true;
    246 }
    247 
    248 // Read exactly len bytes into buffer or fail
    249 bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
    250     int remaining = len;
    251     while (remaining > 0) {
    252         int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
    253         if (ret == 0) {
    254             ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
    255             return false;
    256         }
    257 
    258         if (ret < 0) {
    259             const int ssl_err = SSL_get_error(ssl, ret);
    260             if (ssl_err == SSL_ERROR_WANT_READ) {
    261                 if (waitForReading(fd) != 1) {
    262                     if (DBG) {
    263                         ALOGW("SSL_read error");
    264                     }
    265                     return false;
    266                 }
    267                 continue;
    268             } else {
    269                 if (DBG) {
    270                     ALOGW("SSL_read error %d", ssl_err);
    271                 }
    272                 return false;
    273             }
    274         }
    275 
    276         remaining -= ret;
    277     }
    278     return true;
    279 }
    280 
    281 DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
    282         uint8_t *response, size_t limit, int *resplen) {
    283     *resplen = 0;  // Zero indicates an error.
    284 
    285     if (DBG) {
    286         ALOGD("%u connecting TCP socket", mMark);
    287     }
    288     android::base::unique_fd fd(makeConnectedSocket());
    289     if (DBG) {
    290         ALOGD("%u connecting SSL", mMark);
    291     }
    292     bssl::UniquePtr<SSL> ssl(sslConnect(fd));
    293     if (ssl == nullptr) {
    294         if (DBG) {
    295             ALOGW("%u SSL connection failed", mMark);
    296         }
    297         return Response::network_error;
    298     }
    299 
    300     uint8_t queryHeader[2];
    301     queryHeader[0] = qlen >> 8;
    302     queryHeader[1] = qlen;
    303     if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
    304         return Response::network_error;
    305     }
    306     if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
    307         return Response::network_error;
    308     }
    309     if (DBG) {
    310         ALOGD("%u SSL_write complete", mMark);
    311     }
    312 
    313     uint8_t responseHeader[2];
    314     if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
    315         if (DBG) {
    316             ALOGW("%u Failed to read 2-byte length header", mMark);
    317         }
    318         return Response::network_error;
    319     }
    320     const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
    321     if (DBG) {
    322         ALOGD("%u Expecting response of size %i", mMark, responseSize);
    323     }
    324     if (responseSize > limit) {
    325         ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
    326         return Response::limit_error;
    327     }
    328     if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
    329         if (DBG) {
    330             ALOGW("%u Failed to read %i bytes", mMark, responseSize);
    331         }
    332         return Response::network_error;
    333     }
    334     if (DBG) {
    335         ALOGD("%u SSL_read complete", mMark);
    336     }
    337 
    338     if (response[0] != query[0] || response[1] != query[1]) {
    339         ALOGE("reply query ID != query ID");
    340         return Response::internal_error;
    341     }
    342 
    343     SSL_shutdown(ssl.get());
    344 
    345     *resplen = responseSize;
    346     return Response::success;
    347 }
    348 
    349 bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
    350         const std::set<std::vector<uint8_t>>& fingerprints) {
    351     if (DBG) {
    352         ALOGD("Beginning validation on %u", netid);
    353     }
    354     // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
    355     // order to prove that it is actually a working DNS over TLS server.
    356     static const char kDnsSafeChars[] =
    357             "abcdefhijklmnopqrstuvwxyz"
    358             "ABCDEFHIJKLMNOPQRSTUVWXYZ"
    359             "0123456789";
    360     const auto c = [](uint8_t rnd) -> uint8_t {
    361         return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
    362     };
    363     uint8_t rnd[8];
    364     arc4random_buf(rnd, ARRAY_SIZE(rnd));
    365     // We could try to use res_mkquery() here, but it's basically the same.
    366     uint8_t query[] = {
    367         rnd[6], rnd[7],  // [0-1]   query ID
    368         1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
    369         0, 1,  // [4-5]   QDCOUNT (number of queries)
    370         0, 0,  // [6-7]   ANCOUNT (number of answers)
    371         0, 0,  // [8-9]   NSCOUNT (number of name server records)
    372         0, 0,  // [10-11] ARCOUNT (number of additional records)
    373         17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
    374             '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
    375         6, 'm', 'e', 't', 'r', 'i', 'c',
    376         7, 'g', 's', 't', 'a', 't', 'i', 'c',
    377         3, 'c', 'o', 'm',
    378         0,  // null terminator of FQDN (root TLD)
    379         0, ns_t_aaaa,  // QTYPE
    380         0, ns_c_in     // QCLASS
    381     };
    382     const int qlen = ARRAY_SIZE(query);
    383 
    384     const int kRecvBufSize = 4 * 1024;
    385     uint8_t recvbuf[kRecvBufSize];
    386 
    387     // At validation time, we only know the netId, so we have to guess/compute the
    388     // corresponding socket mark.
    389     Fwmark fwmark;
    390     fwmark.permission = PERMISSION_SYSTEM;
    391     fwmark.explicitlySelected = true;
    392     fwmark.protectedFromVpn = true;
    393     fwmark.netId = netid;
    394     unsigned mark = fwmark.intValue;
    395     DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
    396     int replylen = 0;
    397     xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
    398     if (replylen == 0) {
    399         if (DBG) {
    400             ALOGD("doQuery failed");
    401         }
    402         return false;
    403     }
    404 
    405     if (replylen < NS_HFIXEDSZ) {
    406         if (DBG) {
    407             ALOGW("short response: %d", replylen);
    408         }
    409         return false;
    410     }
    411 
    412     const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
    413     if (qdcount != 1) {
    414         ALOGW("reply query count != 1: %d", qdcount);
    415         return false;
    416     }
    417 
    418     const int ancount = (recvbuf[6] << 8) | recvbuf[7];
    419     if (DBG) {
    420         ALOGD("%u answer count: %d", netid, ancount);
    421     }
    422 
    423     // TODO: Further validate the response contents (check for valid AAAA record, ...).
    424     // Note that currently, integration tests rely on this function accepting a
    425     // response with zero records.
    426 #if 0
    427     for (int i = 0; i < resplen; i++) {
    428         ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
    429     }
    430 #endif
    431     return true;
    432 }
    433 
    434 }  // namespace net
    435 }  // namespace android
    436