Home | History | Annotate | Download | only in fs
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 #include "common/libs/fs/shared_fd.h"
     17 
     18 #include <sys/types.h>
     19 #include <sys/stat.h>
     20 #include <cstddef>
     21 #include <errno.h>
     22 #include <fcntl.h>
     23 #include <netinet/in.h>
     24 #include <unistd.h>
     25 #include <algorithm>
     26 
     27 #include "common/libs/auto_resources/auto_resources.h"
     28 #include "common/libs/glog/logging.h"
     29 #include "common/libs/fs/shared_select.h"
     30 
     31 // #define ENABLE_GCE_SHARED_FD_LOGGING 1
     32 
     33 namespace {
     34 using cvd::SharedFDSet;
     35 
     36 void MarkAll(const SharedFDSet& input, fd_set* dest, int* max_index) {
     37   for (SharedFDSet::const_iterator it = input.begin(); it != input.end();
     38        ++it) {
     39     (*it)->Set(dest, max_index);
     40   }
     41 }
     42 
     43 void CheckMarked(fd_set* in_out_mask, SharedFDSet* in_out_set) {
     44   if (!in_out_set) {
     45     return;
     46   }
     47   SharedFDSet save;
     48   save.swap(in_out_set);
     49   for (SharedFDSet::iterator it = save.begin(); it != save.end(); ++it) {
     50     if ((*it)->IsSet(in_out_mask)) {
     51       in_out_set->Set(*it);
     52     }
     53   }
     54 }
     55 }  // namespace
     56 
     57 namespace cvd {
     58 
     59 bool FileInstance::CopyFrom(FileInstance& in) {
     60   AutoFreeBuffer buffer;
     61   buffer.Resize(8192);
     62   while (true) {
     63     ssize_t num_read = in.Read(buffer.data(), buffer.size());
     64     if (!num_read) {
     65       return true;
     66     }
     67     if (num_read == -1) {
     68       return false;
     69     }
     70     if (num_read > 0) {
     71       if (Write(buffer.data(), num_read) != num_read) {
     72         // The caller will have to log an appropriate message.
     73         return false;
     74       }
     75     }
     76   }
     77   return true;
     78 }
     79 
     80 bool FileInstance::CopyFrom(FileInstance& in, size_t length) {
     81   AutoFreeBuffer buffer;
     82   buffer.Resize(8192);
     83   while (length > 0) {
     84     ssize_t num_read = in.Read(buffer.data(), std::min(buffer.size(), length));
     85     length -= num_read;
     86     if (num_read <= 0) {
     87       return false;
     88     }
     89     if (Write(buffer.data(), num_read) != num_read) {
     90       // The caller will have to log an appropriate message.
     91       return false;
     92     }
     93   }
     94   return true;
     95 }
     96 
     97 void FileInstance::Close() {
     98   AutoFreeBuffer message;
     99   if (fd_ == -1) {
    100     errno_ = EBADF;
    101   } else if (close(fd_) == -1) {
    102     errno_ = errno;
    103     if (identity_.size()) {
    104       message.PrintF("%s: %s failed (%s)", __FUNCTION__, identity_.data(),
    105                      StrError());
    106       Log(message.data());
    107     }
    108   } else {
    109     if (identity_.size()) {
    110       message.PrintF("%s: %s succeeded", __FUNCTION__, identity_.data());
    111       Log(message.data());
    112     }
    113   }
    114   fd_ = -1;
    115 }
    116 
    117 void FileInstance::Identify(const char* identity) {
    118   identity_.PrintF("fd=%d @%p is %s", fd_, this, identity);
    119   AutoFreeBuffer message;
    120   message.PrintF("%s: %s", __FUNCTION__, identity_.data());
    121   Log(message.data());
    122 }
    123 
    124 bool FileInstance::IsSet(fd_set* in) const {
    125   if (IsOpen() && FD_ISSET(fd_, in)) {
    126     return true;
    127   }
    128   return false;
    129 }
    130 
    131 #if ENABLE_GCE_SHARED_FD_LOGGING
    132 void FileInstance::Log(const char* message) {
    133   LOG(INFO) << message;
    134 }
    135 #else
    136 void FileInstance::Log(const char*) {}
    137 #endif
    138 
    139 void FileInstance::Set(fd_set* dest, int* max_index) const {
    140   if (!IsOpen()) {
    141     return;
    142   }
    143   if (fd_ >= *max_index) {
    144     *max_index = fd_ + 1;
    145   }
    146   FD_SET(fd_, dest);
    147 }
    148 
    149 int Select(SharedFDSet* read_set, SharedFDSet* write_set,
    150            SharedFDSet* error_set, struct timeval* timeout) {
    151   int max_index = 0;
    152   fd_set readfds;
    153   FD_ZERO(&readfds);
    154   if (read_set) {
    155     MarkAll(*read_set, &readfds, &max_index);
    156   }
    157   fd_set writefds;
    158   FD_ZERO(&writefds);
    159   if (write_set) {
    160     MarkAll(*write_set, &writefds, &max_index);
    161   }
    162   fd_set errorfds;
    163   FD_ZERO(&errorfds);
    164   if (error_set) {
    165     MarkAll(*error_set, &errorfds, &max_index);
    166   }
    167 
    168   int rval = TEMP_FAILURE_RETRY(
    169       select(max_index, &readfds, &writefds, &errorfds, timeout));
    170   FileInstance::Log("select\n");
    171   CheckMarked(&readfds, read_set);
    172   CheckMarked(&writefds, write_set);
    173   CheckMarked(&errorfds, error_set);
    174   return rval;
    175 }
    176 
    177 static void MakeAddress(const char* name, bool abstract,
    178                         struct sockaddr_un* dest, socklen_t* len) {
    179   memset(dest, 0, sizeof(*dest));
    180   dest->sun_family = AF_UNIX;
    181   // sun_path is NOT expected to be nul-terminated.
    182   // See man 7 unix.
    183   size_t namelen;
    184   if (abstract) {
    185     // ANDROID_SOCKET_NAMESPACE_ABSTRACT
    186     namelen = strlen(name);
    187     CHECK_LE(namelen, sizeof(dest->sun_path) - 1)
    188         << "MakeAddress failed. Name=" << name << " is longer than allowed.";
    189     dest->sun_path[0] = 0;
    190     memcpy(dest->sun_path + 1, name, namelen);
    191   } else {
    192     // ANDROID_SOCKET_NAMESPACE_RESERVED
    193     // ANDROID_SOCKET_NAMESPACE_FILESYSTEM
    194     // TODO(pinghao): Distinguish between them?
    195     namelen = strlen(name);
    196     CHECK_LE(namelen, sizeof(dest->sun_path))
    197         << "MakeAddress failed. Name=" << name << " is longer than allowed.";
    198     strncpy(dest->sun_path, name, strlen(name));
    199   }
    200   *len = namelen + offsetof(struct sockaddr_un, sun_path) + 1;
    201 }
    202 
    203 SharedFD SharedFD::SocketSeqPacketServer(const char* name, mode_t mode) {
    204   return SocketLocalServer(name, false, SOCK_SEQPACKET, mode);
    205 }
    206 
    207 SharedFD SharedFD::SocketSeqPacketClient(const char* name) {
    208   return SocketLocalClient(name, false, SOCK_SEQPACKET);
    209 }
    210 
    211 SharedFD SharedFD::TimerFD(int clock, int flags) {
    212   int fd = timerfd_create(clock, flags);
    213   if (fd == -1) {
    214     return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(fd, errno)));
    215   } else {
    216     return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(fd, 0)));
    217   }
    218 }
    219 
    220 SharedFD SharedFD::Accept(const FileInstance& listener, struct sockaddr* addr,
    221                           socklen_t* addrlen) {
    222   return SharedFD(
    223       std::shared_ptr<FileInstance>(listener.Accept(addr, addrlen)));
    224 }
    225 
    226 SharedFD SharedFD::Accept(const FileInstance& listener) {
    227   return SharedFD::Accept(listener, NULL, NULL);
    228 }
    229 
    230 SharedFD SharedFD::Dup(int unmanaged_fd) {
    231   int fd = dup(unmanaged_fd);
    232   return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(fd, errno)));
    233 }
    234 
    235 bool SharedFD::Pipe(SharedFD* fd0, SharedFD* fd1) {
    236   int fds[2];
    237   int rval = pipe(fds);
    238   if (rval != -1) {
    239     (*fd0) = std::shared_ptr<FileInstance>(new FileInstance(fds[0], errno));
    240     (*fd1) = std::shared_ptr<FileInstance>(new FileInstance(fds[1], errno));
    241     return true;
    242   }
    243   return false;
    244 }
    245 
    246 SharedFD SharedFD::Event(int initval, int flags) {
    247   return std::shared_ptr<FileInstance>(
    248       new FileInstance(eventfd(initval, flags), errno));
    249 }
    250 
    251 SharedFD SharedFD::Epoll(int flags) {
    252   return std::shared_ptr<FileInstance>(
    253       new FileInstance(epoll_create1(flags), errno));
    254 }
    255 
    256 bool SharedFD::SocketPair(int domain, int type, int protocol,
    257                           SharedFD* fd0, SharedFD* fd1) {
    258   int fds[2];
    259   int rval = socketpair(domain, type, protocol, fds);
    260   if (rval != -1) {
    261     (*fd0) = std::shared_ptr<FileInstance>(new FileInstance(fds[0], errno));
    262     (*fd1) = std::shared_ptr<FileInstance>(new FileInstance(fds[1], errno));
    263     return true;
    264   }
    265   return false;
    266 }
    267 
    268 SharedFD SharedFD::Open(const char* path, int flags, mode_t mode) {
    269   int fd = TEMP_FAILURE_RETRY(open(path, flags, mode));
    270   if (fd == -1) {
    271     return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(fd, errno)));
    272   } else {
    273     return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(fd, 0)));
    274   }
    275 }
    276 
    277 SharedFD SharedFD::Creat(const char* path, mode_t mode) {
    278   return SharedFD::Open(path, O_CREAT|O_WRONLY|O_TRUNC, mode);
    279 }
    280 
    281 SharedFD SharedFD::Socket(int domain, int socket_type, int protocol) {
    282   int fd = TEMP_FAILURE_RETRY(socket(domain, socket_type, protocol));
    283   if (fd == -1) {
    284     return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(fd, errno)));
    285   } else {
    286     return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(fd, 0)));
    287   }
    288 }
    289 
    290 SharedFD SharedFD::ErrorFD(int error) {
    291   return SharedFD(std::shared_ptr<FileInstance>(new FileInstance(-1, error)));
    292 }
    293 
    294 SharedFD SharedFD::SocketLocalClient(const char* name, bool abstract,
    295                                      int in_type) {
    296   struct sockaddr_un addr;
    297   socklen_t addrlen;
    298   MakeAddress(name, abstract, &addr, &addrlen);
    299   SharedFD rval = SharedFD::Socket(PF_UNIX, in_type, 0);
    300   if (!rval->IsOpen()) {
    301     return rval;
    302   }
    303   if (rval->Connect(reinterpret_cast<sockaddr*>(&addr), addrlen) == -1) {
    304     return SharedFD::ErrorFD(rval->GetErrno());
    305   }
    306   return rval;
    307 }
    308 
    309 SharedFD SharedFD::SocketLocalClient(int port, int type) {
    310   sockaddr_in addr{};
    311   addr.sin_family = AF_INET;
    312   addr.sin_port = htons(port);
    313   addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    314   SharedFD rval = SharedFD::Socket(AF_INET, type, 0);
    315   if (!rval->IsOpen()) {
    316     return rval;
    317   }
    318   if (rval->Connect(reinterpret_cast<const sockaddr*>(&addr),
    319                     sizeof addr) < 0) {
    320     return SharedFD::ErrorFD(rval->GetErrno());
    321   }
    322   return rval;
    323 }
    324 
    325 SharedFD SharedFD::SocketLocalServer(int port, int type) {
    326   struct sockaddr_in addr;
    327   memset(&addr, 0, sizeof(addr));
    328   addr.sin_family = AF_INET;
    329   addr.sin_port = htons(port);
    330   addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    331   SharedFD rval = SharedFD::Socket(AF_INET, type, 0);
    332   if(!rval->IsOpen()) {
    333     return rval;
    334   }
    335   int n = 1;
    336   if (rval->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &n, sizeof(n)) == -1) {
    337     LOG(ERROR) << "SetSockOpt failed " << rval->StrError();
    338     return SharedFD::ErrorFD(rval->GetErrno());
    339   }
    340   if(rval->Bind(reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {
    341     LOG(ERROR) << "Bind failed " << rval->StrError();
    342     return SharedFD::ErrorFD(rval->GetErrno());
    343   }
    344   if (type == SOCK_STREAM) {
    345     if (rval->Listen(4) < 0) {
    346       LOG(ERROR) << "Listen failed " << rval->StrError();
    347       return SharedFD::ErrorFD(rval->GetErrno());
    348     }
    349   }
    350   return rval;
    351 }
    352 
    353 SharedFD SharedFD::SocketLocalServer(const char* name, bool abstract,
    354                                      int in_type, mode_t mode) {
    355   // DO NOT UNLINK addr.sun_path. It does NOT have to be null-terminated.
    356   // See man 7 unix for more details.
    357   if (!abstract) (void)unlink(name);
    358 
    359   struct sockaddr_un addr;
    360   socklen_t addrlen;
    361   MakeAddress(name, abstract, &addr, &addrlen);
    362   SharedFD rval = SharedFD::Socket(PF_UNIX, in_type, 0);
    363   if (!rval->IsOpen()) {
    364     return rval;
    365   }
    366 
    367   int n = 1;
    368   if (rval->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &n, sizeof(n)) == -1) {
    369     LOG(ERROR) << "SetSockOpt failed " << rval->StrError();
    370     return SharedFD::ErrorFD(rval->GetErrno());
    371   }
    372   if (rval->Bind(reinterpret_cast<sockaddr*>(&addr), addrlen) == -1) {
    373     LOG(ERROR) << "Bind failed; name=" << name << ": " << rval->StrError();
    374     return SharedFD::ErrorFD(rval->GetErrno());
    375   }
    376 
    377   /* Only the bottom bits are really the socket type; there are flags too. */
    378   constexpr int SOCK_TYPE_MASK = 0xf;
    379 
    380   // Connection oriented sockets: start listening.
    381   if ((in_type & SOCK_TYPE_MASK) == SOCK_STREAM) {
    382     // Follows the default from socket_local_server
    383     if (rval->Listen(1) == -1) {
    384       LOG(ERROR) << "Listen failed: " << rval->StrError();
    385       return SharedFD::ErrorFD(rval->GetErrno());
    386     }
    387   }
    388 
    389   if (!abstract) {
    390     if (TEMP_FAILURE_RETRY(chmod(name, mode)) == -1) {
    391       LOG(ERROR) << "chmod failed: " << strerror(errno);
    392       // However, continue since we do have a listening socket
    393     }
    394   }
    395   return rval;
    396 }
    397 
    398 SharedFD SharedFD::VsockServer(unsigned int port, int type) {
    399   auto vsock = cvd::SharedFD::Socket(AF_VSOCK, type, 0);
    400   if (!vsock->IsOpen()) {
    401     return vsock;
    402   }
    403   sockaddr_vm addr{};
    404   addr.svm_family = AF_VSOCK;
    405   addr.svm_port = port;
    406   addr.svm_cid = VMADDR_CID_ANY;
    407   auto casted_addr = reinterpret_cast<sockaddr*>(&addr);
    408   if (vsock->Bind(casted_addr, sizeof(addr)) == -1) {
    409     LOG(ERROR) << "Bind failed (" << vsock->StrError() << ")";
    410     return SharedFD::ErrorFD(vsock->GetErrno());
    411   }
    412   if (type == SOCK_STREAM) {
    413     if (vsock->Listen(4) < 0) {
    414       LOG(ERROR) << "Listen failed (" << vsock->StrError() << ")";
    415       return SharedFD::ErrorFD(vsock->GetErrno());
    416     }
    417   }
    418   return vsock;
    419 }
    420 
    421 SharedFD SharedFD::VsockClient(unsigned int cid, unsigned int port, int type) {
    422   auto vsock = cvd::SharedFD::Socket(AF_VSOCK, type, 0);
    423   if (!vsock->IsOpen()) {
    424     return vsock;
    425   }
    426   sockaddr_vm addr{};
    427   addr.svm_family = AF_VSOCK;
    428   addr.svm_port = port;
    429   addr.svm_cid = cid;
    430   auto casted_addr = reinterpret_cast<sockaddr*>(&addr);
    431   if (vsock->Connect(casted_addr, sizeof(addr)) == -1) {
    432     return SharedFD::ErrorFD(vsock->GetErrno());
    433   }
    434   return vsock;
    435 }
    436 
    437 }  // namespace cvd
    438