Home | History | Annotate | Download | only in lib
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <cassert>
     18 
     19 #include "common/vsoc/lib/circqueue_impl.h"
     20 #include "common/vsoc/lib/lock_guard.h"
     21 #include "common/vsoc/lib/socket_forward_region_view.h"
     22 #include "common/vsoc/shm/lock.h"
     23 #include "common/vsoc/shm/socket_forward_layout.h"
     24 
     25 using vsoc::layout::socket_forward::Queue;
     26 using vsoc::layout::socket_forward::QueuePair;
     27 using vsoc::layout::socket_forward::QueueState;
     28 // store the read and write direction as variables to keep the ifdefs and macros
     29 // in later code to a minimum
     30 constexpr auto ReadDirection = &QueuePair::
     31 #ifdef CUTTLEFISH_HOST
     32                                    guest_to_host;
     33 #else
     34                                    host_to_guest;
     35 #endif
     36 
     37 constexpr auto WriteDirection = &QueuePair::
     38 #ifdef CUTTLEFISH_HOST
     39                                     host_to_guest;
     40 #else
     41                                     guest_to_host;
     42 #endif
     43 
     44 constexpr auto kOtherSideClosed = QueueState::
     45 #ifdef CUTTLEFISH_HOST
     46     GUEST_CLOSED;
     47 #else
     48     HOST_CLOSED;
     49 #endif
     50 
     51 constexpr auto kThisSideClosed = QueueState::
     52 #ifdef CUTTLEFISH_HOST
     53     HOST_CLOSED;
     54 #else
     55     GUEST_CLOSED;
     56 #endif
     57 
     58 using vsoc::socket_forward::SocketForwardRegionView;
     59 
     60 void SocketForwardRegionView::Recv(int connection_id, Packet* packet) {
     61   CHECK(packet != nullptr);
     62   do {
     63     (data()->queues_[connection_id].*ReadDirection)
     64         .queue.Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
     65   } while (packet->IsBegin());
     66   // TODO(haining) check packet generation number
     67   CHECK(!packet->empty()) << "zero-size data message received";
     68   CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size";
     69 }
     70 
     71 bool SocketForwardRegionView::Send(int connection_id, const Packet& packet) {
     72   CHECK(!packet.empty());
     73   CHECK_LE(packet.payload_length(), kMaxPayloadSize);
     74 
     75   // NOTE this is check-then-act but I think that it's okay. Worst case is that
     76   // we send one-too-many packets.
     77   auto& queue_pair = data()->queues_[connection_id];
     78   {
     79     auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
     80     if ((queue_pair.*WriteDirection).queue_state_ == kOtherSideClosed) {
     81       LOG(INFO) << "connection closed, not sending\n";
     82       return false;
     83     }
     84     CHECK((queue_pair.*WriteDirection).queue_state_ != QueueState::INACTIVE);
     85   }
     86   // TODO(haining) set packet generation number
     87   (data()->queues_[connection_id].*WriteDirection)
     88       .queue.Write(this, packet.raw_data(), packet.raw_data_length());
     89   return true;
     90 }
     91 
     92 void SocketForwardRegionView::IgnoreUntilBegin(int connection_id,
     93                                                std::uint32_t generation) {
     94   Packet packet{};
     95   do {
     96     (data()->queues_[connection_id].*ReadDirection)
     97         .queue.Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
     98   } while (!packet.IsBegin() || packet.generation() < generation);
     99 }
    100 
    101 bool SocketForwardRegionView::IsOtherSideRecvClosed(int connection_id) {
    102   auto& queue_pair = data()->queues_[connection_id];
    103   auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
    104   auto& queue = queue_pair.*WriteDirection;
    105   return queue.queue_state_ == kOtherSideClosed ||
    106          queue.queue_state_ == QueueState::INACTIVE;
    107 }
    108 
    109 void SocketForwardRegionView::ResetQueueStates(QueuePair* queue_pair) {
    110   using vsoc::layout::socket_forward::Queue;
    111   auto guard = make_lock_guard(&queue_pair->queue_state_lock_);
    112   Queue* queues[] = {&queue_pair->host_to_guest, &queue_pair->guest_to_host};
    113   for (auto* queue : queues) {
    114     auto& state = queue->queue_state_;
    115     switch (state) {
    116       case QueueState::HOST_CONNECTED:
    117       case kOtherSideClosed:
    118         LOG(DEBUG)
    119             << "host_connected or other side is closed, marking inactive";
    120         state = QueueState::INACTIVE;
    121         break;
    122 
    123       case QueueState::BOTH_CONNECTED:
    124         LOG(DEBUG) << "both_connected, marking this side closed";
    125         state = kThisSideClosed;
    126         break;
    127 
    128       case kThisSideClosed:
    129         [[fallthrough]];
    130       case QueueState::INACTIVE:
    131         LOG(DEBUG) << "inactive or this side closed, not changing state";
    132         break;
    133     }
    134   }
    135 }
    136 
// Walks every queue pair in the region and shuts down whatever connection
// state is still recorded there, then bumps the region's generation number.
//
// For each pair whose write-direction queue is still connected, a begin packet
// and an end packet stamped with the current generation are sent before the
// pair's states are reset via ResetQueueStates().
void SocketForwardRegionView::CleanUpPreviousConnections() {
  // NOTE(review): Recover() is defined on the layout type, outside this file;
  // presumably it repairs state left behind by a previous run — confirm.
  data()->Recover();
  int connection_id = 0;
  auto current_generation = generation();
  auto begin_packet = Packet::MakeBegin();
  begin_packet.set_generation(current_generation);
  auto end_packet = Packet::MakeEnd();
  end_packet.set_generation(current_generation);
  for (auto&& queue_pair : data()->queues_) {
    QueueState state{};
    {
      // Snapshot the write-direction state under the lock; the lock is
      // released again before any packet is sent.
      auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
      state = (queue_pair.*WriteDirection).queue_state_;
#ifndef CUTTLEFISH_HOST
      // Guest build only: a pair the host has connected but the guest has not
      // yet accepted is promoted to BOTH_CONNECTED in both directions, so it
      // is handled like any other connected pair below.
      if (state == QueueState::HOST_CONNECTED) {
        state = (queue_pair.*WriteDirection).queue_state_ =
            (queue_pair.*ReadDirection).queue_state_ =
                QueueState::BOTH_CONNECTED;
      }
#endif
    }

    // Host build additionally treats HOST_CONNECTED as "connected" here.
    if (state == QueueState::BOTH_CONNECTED
#ifdef CUTTLEFISH_HOST
        || state == QueueState::HOST_CONNECTED
#endif
    ) {
      LOG(INFO) << "found connected write queue state, sending begin and end";
      // The bool results of Send() are not inspected here.
      Send(connection_id, begin_packet);
      Send(connection_id, end_packet);
    }
    ResetQueueStates(&queue_pair);
    ++connection_id;
  }
  // Receivers created from now on compare packet generations against the new
  // value (see IgnoreUntilBegin).
  ++data()->generation_num;
}
    173 
    174 void SocketForwardRegionView::MarkQueueDisconnected(
    175     int connection_id, Queue QueuePair::*direction) {
    176   auto& queue_pair = data()->queues_[connection_id];
    177   auto& queue = queue_pair.*direction;
    178 
    179 #ifdef CUTTLEFISH_HOST
    180   // if the host has connected but the guest hasn't seen it yet, wait for the
    181   // guest to connect so the protocol can follow the normal state transition.
    182   while (true) {
    183     auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
    184     if (queue.queue_state_ != QueueState::HOST_CONNECTED) {
    185       break;
    186     }
    187     LOG(WARNING) << "closing queue in HOST_CONNECTED state. waiting";
    188     sleep(1);
    189   }
    190 #endif
    191 
    192   auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
    193 
    194   queue.queue_state_ = queue.queue_state_ == kOtherSideClosed
    195                            ? QueueState::INACTIVE
    196                            : kThisSideClosed;
    197 }
    198 
// Marks this side's write-direction queue of `connection_id` as disconnected.
// Thin wrapper over MarkQueueDisconnected() with WriteDirection.
void SocketForwardRegionView::MarkSendQueueDisconnected(int connection_id) {
  MarkQueueDisconnected(connection_id, WriteDirection);
}
    202 
// Marks this side's read-direction queue of `connection_id` as disconnected.
// Thin wrapper over MarkQueueDisconnected() with ReadDirection.
void SocketForwardRegionView::MarkRecvQueueDisconnected(int connection_id) {
  MarkQueueDisconnected(connection_id, ReadDirection);
}
    206 
// Returns the port_ value stored in the queue pair for `connection_id`.
// No lock is taken and `connection_id` is not range-checked here.
int SocketForwardRegionView::port(int connection_id) {
  return data()->queues_[connection_id].port_;
}
    210 
// Returns the region's current generation number (generation_num), which
// CleanUpPreviousConnections() increments after each cleanup pass.
std::uint32_t SocketForwardRegionView::generation() {
  return data()->generation_num;
}
    214 
    215 #ifdef CUTTLEFISH_HOST
    216 int SocketForwardRegionView::AcquireConnectionID(int port) {
    217   while (true) {
    218     int id = 0;
    219     for (auto&& queue_pair : data()->queues_) {
    220       LOG(DEBUG) << "locking and checking queue at index " << id;
    221       auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
    222       if (queue_pair.host_to_guest.queue_state_ == QueueState::INACTIVE &&
    223           queue_pair.guest_to_host.queue_state_ == QueueState::INACTIVE) {
    224         queue_pair.port_ = port;
    225         queue_pair.host_to_guest.queue_state_ = QueueState::HOST_CONNECTED;
    226         queue_pair.guest_to_host.queue_state_ = QueueState::HOST_CONNECTED;
    227         LOG(DEBUG) << "acquired queue " << id
    228                    << ". current seq_num: " << data()->seq_num;
    229         ++data()->seq_num;
    230         SendSignal(layout::Sides::Peer, &data()->seq_num);
    231         return id;
    232       }
    233       ++id;
    234     }
    235     LOG(ERROR) << "no remaining shm queues for connection, sleeping.";
    236     sleep(10);
    237   }
    238 }
    239 
// Host only. Acquires a connection slot for `port` (AcquireConnectionID()
// loops until one is free) and returns a Sender/Receiver pair bound to that
// slot and to the generation number read right after the slot was acquired.
std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
SocketForwardRegionView::OpenConnection(int port) {
  int connection_id = AcquireConnectionID(port);
  LOG(INFO) << "Acquired connection with id " << connection_id;
  auto current_generation = generation();
  return {Sender{this, connection_id, current_generation},
          Receiver{this, connection_id, current_generation}};
}
    248 #else
    249 int SocketForwardRegionView::GetWaitingConnectionID() {
    250   while (data()->seq_num == last_seq_number_) {
    251     WaitForSignal(&data()->seq_num, last_seq_number_);
    252   }
    253   ++last_seq_number_;
    254   int id = 0;
    255   for (auto&& queue_pair : data()->queues_) {
    256     LOG(DEBUG) << "locking and checking queue at index " << id;
    257     auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
    258     if (queue_pair.host_to_guest.queue_state_ == QueueState::HOST_CONNECTED) {
    259       CHECK(queue_pair.guest_to_host.queue_state_ ==
    260             QueueState::HOST_CONNECTED);
    261       LOG(DEBUG) << "found waiting connection at index " << id;
    262       queue_pair.host_to_guest.queue_state_ = QueueState::BOTH_CONNECTED;
    263       queue_pair.guest_to_host.queue_state_ = QueueState::BOTH_CONNECTED;
    264       return id;
    265     }
    266     ++id;
    267   }
    268   return -1;
    269 }
    270 
    271 std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
    272 SocketForwardRegionView::AcceptConnection() {
    273   int connection_id = -1;
    274   while (connection_id < 0) {
    275     connection_id = GetWaitingConnectionID();
    276   }
    277   LOG(INFO) << "Accepted connection with id " << connection_id;
    278 
    279   auto current_generation = generation();
    280   return {Sender{this, connection_id, current_generation},
    281           Receiver{this, connection_id, current_generation}};
    282 }
    283 #endif
    284 
    285 // --- Connection ---- //
    286 void SocketForwardRegionView::Receiver::Recv(Packet* packet) {
    287   if (!got_begin_) {
    288     view_->IgnoreUntilBegin(connection_id_, generation_);
    289     got_begin_ = true;
    290   }
    291   return view_->Recv(connection_id_, packet);
    292 }
    293 
// Returns true when the view reports that the other side's receiving end of
// this sender's connection is closed (or the queue is inactive); forwards to
// SocketForwardRegionView::IsOtherSideRecvClosed().
bool SocketForwardRegionView::Sender::closed() const {
  return view_->IsOtherSideRecvClosed(connection_id_);
}
    297 
// Sends `packet` on this sender's connection by forwarding to
// SocketForwardRegionView::Send(); returns that call's result (false when the
// connection is already closed by the other side).
bool SocketForwardRegionView::Sender::Send(const Packet& packet) {
  return view_->Send(connection_id_, packet);
}
    301