1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "dns/DnsTlsTransport.h" 18 19 #include <arpa/inet.h> 20 #include <arpa/nameser.h> 21 #include <errno.h> 22 #include <openssl/err.h> 23 #include <openssl/ssl.h> 24 #include <stdlib.h> 25 26 #define LOG_TAG "DnsTlsTransport" 27 #define DBG 0 28 29 #include "log/log.h" 30 #include "Fwmark.h" 31 #undef ADD // already defined in nameser.h 32 #include "NetdConstants.h" 33 #include "Permission.h" 34 35 36 namespace android { 37 namespace net { 38 39 namespace { 40 41 bool setNonBlocking(int fd, bool enabled) { 42 int flags = fcntl(fd, F_GETFL); 43 if (flags < 0) return false; 44 45 if (enabled) { 46 flags |= O_NONBLOCK; 47 } else { 48 flags &= ~O_NONBLOCK; 49 } 50 return (fcntl(fd, F_SETFL, flags) == 0); 51 } 52 53 int waitForReading(int fd) { 54 fd_set fds; 55 FD_ZERO(&fds); 56 FD_SET(fd, &fds); 57 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr)); 58 if (DBG && ret <= 0) { 59 ALOGD("select"); 60 } 61 return ret; 62 } 63 64 int waitForWriting(int fd) { 65 fd_set fds; 66 FD_ZERO(&fds); 67 FD_SET(fd, &fds); 68 const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr)); 69 if (DBG && ret <= 0) { 70 ALOGD("select"); 71 } 72 return ret; 73 } 74 75 } // namespace 76 77 android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const { 78 android::base::unique_fd fd; 79 int type = SOCK_NONBLOCK | SOCK_CLOEXEC; 80 switch (mProtocol) { 81 case IPPROTO_TCP: 82 type |= SOCK_STREAM; 83 break; 84 default: 85 errno = EPROTONOSUPPORT; 86 return fd; 87 } 88 89 fd.reset(socket(mAddr.ss_family, type, mProtocol)); 90 if (fd.get() == -1) { 91 return fd; 92 } 93 94 const socklen_t len = sizeof(mMark); 95 if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) { 96 fd.reset(); 97 } else if (connect(fd.get(), 98 reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0 99 && errno != EINPROGRESS) { 100 fd.reset(); 101 } 102 103 return fd; 104 } 105 106 bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) { 107 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL); 108 unsigned char spki[spki_len]; 109 unsigned char* temp = spki; 110 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) { 111 ALOGW("SPKI length mismatch"); 112 return false; 113 } 114 out->resize(SHA256_SIZE); 115 unsigned int digest_len = 0; 116 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL); 117 if (ret != 1) { 118 ALOGW("Server cert digest extraction failed"); 119 return false; 120 } 121 if (digest_len != out->size()) { 122 ALOGW("Wrong digest length: %d", digest_len); 123 return false; 124 } 125 return true; 126 } 127 128 SSL* DnsTlsTransport::sslConnect(int fd) { 129 if (fd < 0) { 130 ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno)); 131 return nullptr; 132 } 133 134 // Set up TLS context. 135 bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method())); 136 if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) || 137 !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) { 138 ALOGD("failed to min/max TLS versions"); 139 return nullptr; 140 } 141 142 bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get())); 143 bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE)); 144 SSL_set_bio(ssl.get(), bio.get(), bio.get()); 145 bio.release(); 146 147 if (!setNonBlocking(fd, false)) { 148 ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd"); 149 return nullptr; 150 } 151 152 for (;;) { 153 if (DBG) { 154 ALOGD("%u Calling SSL_connect", mMark); 155 } 156 int ret = SSL_connect(ssl.get()); 157 if (DBG) { 158 ALOGD("%u SSL_connect returned %d", mMark, ret); 159 } 160 if (ret == 1) break; // SSL handshake complete; 161 162 const int ssl_err = SSL_get_error(ssl.get(), ret); 163 switch (ssl_err) { 164 case SSL_ERROR_WANT_READ: 165 if (waitForReading(fd) != 1) { 166 ALOGW("SSL_connect read error"); 167 return nullptr; 168 } 169 break; 170 case SSL_ERROR_WANT_WRITE: 171 if (waitForWriting(fd) != 1) { 172 ALOGW("SSL_connect write error"); 173 return nullptr; 174 } 175 break; 176 default: 177 ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno); 178 return nullptr; 179 } 180 } 181 182 if (!mFingerprints.empty()) { 183 if (DBG) { 184 ALOGD("Checking DNS over TLS fingerprint"); 185 } 186 // TODO: Follow the cert chain and check all the way up. 187 bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get())); 188 if (!cert) { 189 ALOGW("Server has null certificate"); 190 return nullptr; 191 } 192 std::vector<uint8_t> digest; 193 if (!getSPKIDigest(cert.get(), &digest)) { 194 ALOGE("Digest computation failed"); 195 return nullptr; 196 } 197 198 if (mFingerprints.count(digest) == 0) { 199 ALOGW("No matching fingerprint"); 200 return nullptr; 201 } 202 if (DBG) { 203 ALOGD("DNS over TLS fingerprint is correct"); 204 } 205 } 206 207 if (DBG) { 208 ALOGD("%u handshake complete", mMark); 209 } 210 return ssl.release(); 211 } 212 213 bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) { 214 if (DBG) { 215 ALOGD("%u Writing %d bytes", mMark, len); 216 } 217 for (;;) { 218 int ret = SSL_write(ssl, buffer, len); 219 if (ret == len) break; // SSL write complete; 220 221 if (ret < 1) { 222 const int ssl_err = SSL_get_error(ssl, ret); 223 switch (ssl_err) { 224 case SSL_ERROR_WANT_WRITE: 225 if (waitForWriting(fd) != 1) { 226 if (DBG) { 227 ALOGW("SSL_write error"); 228 } 229 return false; 230 } 231 continue; 232 case 0: 233 break; // SSL write complete; 234 default: 235 if (DBG) { 236 ALOGW("SSL_write error %d", ssl_err); 237 } 238 return false; 239 } 240 } 241 } 242 if (DBG) { 243 ALOGD("%u Wrote %d bytes", mMark, len); 244 } 245 return true; 246 } 247 248 // Read exactly len bytes into buffer or fail 249 bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) { 250 int remaining = len; 251 while (remaining > 0) { 252 int ret = SSL_read(ssl, buffer + (len - remaining), remaining); 253 if (ret == 0) { 254 ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len); 255 return false; 256 } 257 258 if (ret < 0) { 259 const int ssl_err = SSL_get_error(ssl, ret); 260 if (ssl_err == SSL_ERROR_WANT_READ) { 261 if (waitForReading(fd) != 1) { 262 if (DBG) { 263 ALOGW("SSL_read error"); 264 } 265 return false; 266 } 267 continue; 268 } else { 269 if (DBG) { 270 ALOGW("SSL_read error %d", ssl_err); 271 } 272 return false; 273 } 274 } 275 276 remaining -= ret; 277 } 278 return true; 279 } 280 281 DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen, 282 uint8_t *response, size_t limit, int *resplen) { 283 *resplen = 0; // Zero indicates an error. 284 285 if (DBG) { 286 ALOGD("%u connecting TCP socket", mMark); 287 } 288 android::base::unique_fd fd(makeConnectedSocket()); 289 if (DBG) { 290 ALOGD("%u connecting SSL", mMark); 291 } 292 bssl::UniquePtr<SSL> ssl(sslConnect(fd)); 293 if (ssl == nullptr) { 294 if (DBG) { 295 ALOGW("%u SSL connection failed", mMark); 296 } 297 return Response::network_error; 298 } 299 300 uint8_t queryHeader[2]; 301 queryHeader[0] = qlen >> 8; 302 queryHeader[1] = qlen; 303 if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) { 304 return Response::network_error; 305 } 306 if (!sslWrite(fd.get(), ssl.get(), query, qlen)) { 307 return Response::network_error; 308 } 309 if (DBG) { 310 ALOGD("%u SSL_write complete", mMark); 311 } 312 313 uint8_t responseHeader[2]; 314 if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) { 315 if (DBG) { 316 ALOGW("%u Failed to read 2-byte length header", mMark); 317 } 318 return Response::network_error; 319 } 320 const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1]; 321 if (DBG) { 322 ALOGD("%u Expecting response of size %i", mMark, responseSize); 323 } 324 if (responseSize > limit) { 325 ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize); 326 return Response::limit_error; 327 } 328 if (!sslRead(fd.get(), ssl.get(), response, responseSize)) { 329 if (DBG) { 330 ALOGW("%u Failed to read %i bytes", mMark, responseSize); 331 } 332 return Response::network_error; 333 } 334 if (DBG) { 335 ALOGD("%u SSL_read complete", mMark); 336 } 337 338 if (response[0] != query[0] || response[1] != query[1]) { 339 ALOGE("reply query ID != query ID"); 340 return Response::internal_error; 341 } 342 343 SSL_shutdown(ssl.get()); 344 345 *resplen = responseSize; 346 return Response::success; 347 } 348 349 bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss, 350 const std::set<std::vector<uint8_t>>& fingerprints) { 351 if (DBG) { 352 ALOGD("Beginning validation on %u", netid); 353 } 354 // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in 355 // order to prove that it is actually a working DNS over TLS server. 356 static const char kDnsSafeChars[] = 357 "abcdefhijklmnopqrstuvwxyz" 358 "ABCDEFHIJKLMNOPQRSTUVWXYZ" 359 "0123456789"; 360 const auto c = [](uint8_t rnd) -> uint8_t { 361 return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))]; 362 }; 363 uint8_t rnd[8]; 364 arc4random_buf(rnd, ARRAY_SIZE(rnd)); 365 // We could try to use res_mkquery() here, but it's basically the same. 366 uint8_t query[] = { 367 rnd[6], rnd[7], // [0-1] query ID 368 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). 369 0, 1, // [4-5] QDCOUNT (number of queries) 370 0, 0, // [6-7] ANCOUNT (number of answers) 371 0, 0, // [8-9] NSCOUNT (number of name server records) 372 0, 0, // [10-11] ARCOUNT (number of additional records) 373 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), 374 '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's', 375 6, 'm', 'e', 't', 'r', 'i', 'c', 376 7, 'g', 's', 't', 'a', 't', 'i', 'c', 377 3, 'c', 'o', 'm', 378 0, // null terminator of FQDN (root TLD) 379 0, ns_t_aaaa, // QTYPE 380 0, ns_c_in // QCLASS 381 }; 382 const int qlen = ARRAY_SIZE(query); 383 384 const int kRecvBufSize = 4 * 1024; 385 uint8_t recvbuf[kRecvBufSize]; 386 387 // At validation time, we only know the netId, so we have to guess/compute the 388 // corresponding socket mark. 389 Fwmark fwmark; 390 fwmark.permission = PERMISSION_SYSTEM; 391 fwmark.explicitlySelected = true; 392 fwmark.protectedFromVpn = true; 393 fwmark.netId = netid; 394 unsigned mark = fwmark.intValue; 395 DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints); 396 int replylen = 0; 397 xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen); 398 if (replylen == 0) { 399 if (DBG) { 400 ALOGD("doQuery failed"); 401 } 402 return false; 403 } 404 405 if (replylen < NS_HFIXEDSZ) { 406 if (DBG) { 407 ALOGW("short response: %d", replylen); 408 } 409 return false; 410 } 411 412 const int qdcount = (recvbuf[4] << 8) | recvbuf[5]; 413 if (qdcount != 1) { 414 ALOGW("reply query count != 1: %d", qdcount); 415 return false; 416 } 417 418 const int ancount = (recvbuf[6] << 8) | recvbuf[7]; 419 if (DBG) { 420 ALOGD("%u answer count: %d", netid, ancount); 421 } 422 423 // TODO: Further validate the response contents (check for valid AAAA record, ...). 424 // Note that currently, integration tests rely on this function accepting a 425 // response with zero records. 426 #if 0 427 for (int i = 0; i < resplen; i++) { 428 ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]); 429 } 430 #endif 431 return true; 432 } 433 434 } // namespace net 435 } // namespace android 436