Home | History | Annotate | Download | only in dns_responder
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "dns_tls_frontend.h"
     18 
     19 #include <netdb.h>
     20 #include <stdio.h>
     21 #include <unistd.h>
     22 #include <sys/poll.h>
     23 #include <sys/socket.h>
     24 #include <sys/types.h>
     25 #include <arpa/inet.h>
     26 #include <openssl/err.h>
     27 #include <openssl/evp.h>
     28 #include <openssl/ssl.h>
     29 
     30 #define LOG_TAG "DnsTlsFrontend"
     31 #include <log/log.h>
     32 
     33 #include <unistd.h>
     34 
     35 namespace {
     36 
     37 const int SHA256_SIZE = 32;
     38 
     39 // Copied from DnsTlsTransport.
     40 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
     41     int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
     42     unsigned char spki[spki_len];
     43     unsigned char* temp = spki;
     44     if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
     45         ALOGE("SPKI length mismatch");
     46         return false;
     47     }
     48     out->resize(SHA256_SIZE);
     49     unsigned int digest_len = 0;
     50     int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
     51     if (ret != 1) {
     52         ALOGE("Server cert digest extraction failed");
     53         return false;
     54     }
     55     if (digest_len != out->size()) {
     56         ALOGE("Wrong digest length: %d", digest_len);
     57         return false;
     58     }
     59     return true;
     60 }
     61 
     62 std::string errno2str() {
     63     char error_msg[512] = { 0 };
     64     if (strerror_r(errno, error_msg, sizeof(error_msg)))
     65         return std::string();
     66     return std::string(error_msg);
     67 }
     68 
     69 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
     70 
     71 std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
     72     char host_str[NI_MAXHOST] = { 0 };
     73     int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
     74                          NI_NUMERICHOST);
     75     if (rv == 0) return std::string(host_str);
     76     return std::string();
     77 }
     78 
     79 bssl::UniquePtr<EVP_PKEY> make_private_key() {
     80     bssl::UniquePtr<BIGNUM> e(BN_new());
     81     if (!e) {
     82         ALOGE("BN_new failed");
     83         return nullptr;
     84     }
     85     if (!BN_set_word(e.get(), RSA_F4)) {
     86         ALOGE("BN_set_word failed");
     87         return nullptr;
     88     }
     89 
     90     bssl::UniquePtr<RSA> rsa(RSA_new());
     91     if (!rsa) {
     92         ALOGE("RSA_new failed");
     93         return nullptr;
     94     }
     95     if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), NULL)) {
     96         ALOGE("RSA_generate_key_ex failed");
     97         return nullptr;
     98     }
     99 
    100     bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
    101     if (!privkey) {
    102         ALOGE("EVP_PKEY_new failed");
    103         return nullptr;
    104     }
    105     if(!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
    106         ALOGE("EVP_PKEY_assign_RSA failed");
    107         return nullptr;
    108     }
    109 
    110     // |rsa| is now owned by |privkey|, so no need to free it.
    111     rsa.release();
    112     return privkey;
    113 }
    114 
    115 bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey) {
    116     bssl::UniquePtr<X509> cert(X509_new());
    117     if (!cert) {
    118         ALOGE("X509_new failed");
    119         return nullptr;
    120     }
    121 
    122     ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);
    123 
    124     // Set one hour expiration.
    125     X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
    126     X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);
    127 
    128     X509_set_pubkey(cert.get(), privkey);
    129 
    130     if (!X509_sign(cert.get(), privkey, EVP_sha256())) {
    131         ALOGE("X509_sign failed");
    132         return nullptr;
    133     }
    134 
    135     return cert;
    136 }
    137 
    138 }
    139 
    140 namespace test {
    141 
    142 bool DnsTlsFrontend::startServer() {
    143     SSL_load_error_strings();
    144     OpenSSL_add_ssl_algorithms();
    145 
    146     ctx_.reset(SSL_CTX_new(TLS_server_method()));
    147     if (!ctx_) {
    148         ALOGE("SSL context creation failed");
    149         return false;
    150     }
    151 
    152     SSL_CTX_set_ecdh_auto(ctx_.get(), 1);
    153 
    154     bssl::UniquePtr<EVP_PKEY> key(make_private_key());
    155     bssl::UniquePtr<X509> cert(make_cert(key.get()));
    156     if (SSL_CTX_use_certificate(ctx_.get(), cert.get()) <= 0) {
    157         ALOGE("SSL_CTX_use_certificate failed");
    158         return false;
    159     }
    160 
    161     if (!getSPKIDigest(cert.get(), &fingerprint_)) {
    162         ALOGE("getSPKIDigest failed");
    163         return false;
    164     }
    165 
    166     if (SSL_CTX_use_PrivateKey(ctx_.get(), key.get()) <= 0 ) {
    167         ALOGE("SSL_CTX_use_PrivateKey failed");
    168         return false;
    169     }
    170 
    171     // Set up TCP server socket for clients.
    172     addrinfo frontend_ai_hints{
    173         .ai_family = AF_UNSPEC,
    174         .ai_socktype = SOCK_STREAM,
    175         .ai_flags = AI_PASSIVE
    176     };
    177     addrinfo* frontend_ai_res;
    178     int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
    179                          &frontend_ai_hints, &frontend_ai_res);
    180     if (rv) {
    181         ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
    182             listen_service_.c_str(), gai_strerror(rv));
    183         return false;
    184     }
    185 
    186     int s = -1;
    187     for (const addrinfo* ai = frontend_ai_res ; ai ; ai = ai->ai_next) {
    188         s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
    189         if (s < 0) continue;
    190         const int one = 1;
    191         setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
    192         if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
    193             APLOGI("bind failed for socket %d", s);
    194             close(s);
    195             s = -1;
    196             continue;
    197         }
    198         std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
    199         ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
    200         break;
    201     }
    202     freeaddrinfo(frontend_ai_res);
    203     if (s < 0) {
    204         ALOGE("server socket creation failed");
    205         return false;
    206     }
    207 
    208     if (listen(s, 1) < 0) {
    209         ALOGE("listen failed");
    210         return false;
    211     }
    212 
    213     socket_ = s;
    214 
    215     // Set up UDP client socket to backend.
    216     addrinfo backend_ai_hints{
    217         .ai_family = AF_UNSPEC,
    218         .ai_socktype = SOCK_DGRAM
    219     };
    220     addrinfo* backend_ai_res;
    221     rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
    222                          &backend_ai_hints, &backend_ai_res);
    223     if (rv) {
    224         ALOGE("backend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
    225             listen_service_.c_str(), gai_strerror(rv));
    226         return false;
    227     }
    228     backend_socket_ = socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
    229         backend_ai_res->ai_protocol);
    230     if (backend_socket_ < 0) {
    231         ALOGE("backend socket creation failed");
    232         return false;
    233     }
    234     connect(backend_socket_, backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);
    235     freeaddrinfo(backend_ai_res);
    236 
    237     {
    238         std::lock_guard<std::mutex> lock(update_mutex_);
    239         handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    240     }
    241     ALOGI("server started successfully");
    242     return true;
    243 }
    244 
    245 void DnsTlsFrontend::requestHandler() {
    246     ALOGD("Request handler started");
    247     struct pollfd fds[1] = {{ .fd = socket_, .events = POLLIN }};
    248 
    249     while (!terminate_) {
    250         int poll_code = poll(fds, 1, 10 /* ms */);
    251         if (poll_code == 0) {
    252             // Timeout.  Poll again.
    253             continue;
    254         } else if (poll_code < 0) {
    255             ALOGW("Poll failed with error %d", poll_code);
    256             // Error.
    257             break;
    258         }
    259         sockaddr_storage addr;
    260         socklen_t len = sizeof(addr);
    261 
    262         ALOGD("Trying to accept a client");
    263         int client = accept(socket_, reinterpret_cast<sockaddr*>(&addr), &len);
    264         ALOGD("Got client socket %d", client);
    265         if (client < 0) {
    266             // Stop
    267             break;
    268         }
    269 
    270         bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
    271         SSL_set_fd(ssl.get(), client);
    272 
    273         ALOGD("Doing SSL handshake");
    274         bool success = false;
    275         if (SSL_accept(ssl.get()) <= 0) {
    276             ALOGI("SSL negotiation failure");
    277         } else {
    278             ALOGD("SSL handshake complete");
    279             success = handleOneRequest(ssl.get());
    280         }
    281 
    282         close(client);
    283 
    284         if (success) {
    285             // Increment queries_ as late as possible, because it represents
    286             // a query that is fully processed, and the response returned to the
    287             // client, including cleanup actions.
    288             ++queries_;
    289         }
    290     }
    291     ALOGD("Request handler terminating");
    292 }
    293 
    294 bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
    295     uint8_t queryHeader[2];
    296     if (SSL_read(ssl, &queryHeader, 2) != 2) {
    297         ALOGI("Not enough header bytes");
    298         return false;
    299     }
    300     const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
    301     uint8_t query[qlen];
    302     if (SSL_read(ssl, &query, qlen) != qlen) {
    303         ALOGI("Not enough query bytes");
    304         return false;
    305     }
    306     int sent = send(backend_socket_, query, qlen, 0);
    307     if (sent != qlen) {
    308         ALOGI("Failed to send query");
    309         return false;
    310     }
    311     const int max_size = 4096;
    312     uint8_t recv_buffer[max_size];
    313     int rlen = recv(backend_socket_, recv_buffer, max_size, 0);
    314     if (rlen <= 0) {
    315         ALOGI("Failed to receive response");
    316         return false;
    317     }
    318     uint8_t responseHeader[2];
    319     responseHeader[0] = rlen >> 8;
    320     responseHeader[1] = rlen;
    321     if (SSL_write(ssl, responseHeader, 2) != 2) {
    322         ALOGI("Failed to write response header");
    323         return false;
    324     }
    325     if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
    326         ALOGI("Failed to write response body");
    327         return false;
    328     }
    329     return true;
    330 }
    331 
    332 bool DnsTlsFrontend::stopServer() {
    333     std::lock_guard<std::mutex> lock(update_mutex_);
    334     if (!running()) {
    335         ALOGI("server not running");
    336         return false;
    337     }
    338     if (terminate_) {
    339         ALOGI("LOGIC ERROR");
    340         return false;
    341     }
    342     ALOGI("stopping frontend");
    343     terminate_ = true;
    344     handler_thread_.join();
    345     close(socket_);
    346     close(backend_socket_);
    347     terminate_ = false;
    348     socket_ = -1;
    349     backend_socket_ = -1;
    350     ctx_.reset();
    351     fingerprint_.clear();
    352     ALOGI("frontend stopped successfully");
    353     return true;
    354 }
    355 
    356 bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
    357     constexpr int intervalMs = 20;
    358     int limit = timeoutMs / intervalMs;
    359     for (int count = 0; count <= limit; ++count) {
    360         bool done = queries_ >= number;
    361         // Always sleep at least one more interval after we are done, to wait for
    362         // any immediate post-query actions that the client may take (such as
    363         // marking this server as reachable during validation).
    364         usleep(intervalMs * 1000);
    365         if (done) {
    366             // For ensuring that calls have sufficient headroom for slow machines
    367             ALOGD("Query arrived in %d/%d of allotted time", count, limit);
    368             return true;
    369         }
    370     }
    371     return false;
    372 }
    373 
    374 }  // namespace test
    375