Home | History | Annotate | Download | only in lib
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 #pragma once
     17 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "common/vsoc/lib/typed_region_view.h"
#include "common/vsoc/shm/socket_forward_layout.h"
     24 
     25 namespace vsoc {
     26 namespace socket_forward {
     27 
     28 struct Header {
     29   std::uint32_t payload_length;
     30   std::uint32_t generation;
     31   enum MessageType : std::uint32_t {
     32     DATA = 0,
     33     BEGIN,
     34     END,
     35   };
     36   MessageType message_type;
     37 };
     38 
     39 constexpr std::size_t kMaxPayloadSize =
     40     layout::socket_forward::kMaxPacketSize - sizeof(Header);
     41 
     42 struct Packet {
     43  private:
     44   Header header_;
     45   using Payload = char[kMaxPayloadSize];
     46   Payload payload_data_;
     47 
     48   static Packet MakePacket(Header::MessageType type) {
     49     Packet packet{};
     50     packet.header_.message_type = type;
     51     return packet;
     52   }
     53 
     54  public:
     55   static Packet MakeBegin() { return MakePacket(Header::BEGIN); }
     56 
     57   static Packet MakeEnd() { return MakePacket(Header::END); }
     58 
     59   // NOTE payload and payload_length must still be set.
     60   static Packet MakeData() { return MakePacket(Header::DATA); }
     61 
     62   bool empty() const { return IsData() && header_.payload_length == 0; }
     63 
     64   void set_payload_length(std::uint32_t length) {
     65     CHECK_LE(length, sizeof payload_data_);
     66     header_.message_type = Header::DATA;
     67     header_.payload_length = length;
     68   }
     69 
     70   std::uint32_t generation() const { return header_.generation; }
     71 
     72   void set_generation(std::uint32_t generation) {
     73     header_.generation = generation;
     74   }
     75 
     76   Payload& payload() { return payload_data_; }
     77 
     78   const Payload& payload() const { return payload_data_; }
     79 
     80   std::uint32_t payload_length() const { return header_.payload_length; }
     81 
     82   bool IsBegin() const { return header_.message_type == Header::BEGIN; }
     83 
     84   bool IsEnd() const { return header_.message_type == Header::END; }
     85 
     86   bool IsData() const { return header_.message_type == Header::DATA; }
     87 
     88   char* raw_data() { return reinterpret_cast<char*>(this); }
     89 
     90   const char* raw_data() const { return reinterpret_cast<const char*>(this); }
     91 
     92   size_t raw_data_length() const { return payload_length() + sizeof header_; }
     93 };
     94 
     95 static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, "");
     96 static_assert(std::is_pod<Packet>{}, "");
     97 
     98 // Data sent will start with a uint32_t indicating the number of bytes being
     99 // sent, followed be the data itself
    100 class SocketForwardRegionView
    101     : public TypedRegionView<SocketForwardRegionView,
    102                              layout::socket_forward::SocketForwardLayout> {
    103  private:
    104 #ifdef CUTTLEFISH_HOST
    105   int AcquireConnectionID(int port);
    106 #else
    107   int GetWaitingConnectionID();
    108 #endif
    109 
    110   // Returns an empty data packet if the other side is closed.
    111   void Recv(int connection_id, Packet* packet);
    112   // Returns true on success
    113   bool Send(int connection_id, const Packet& packet);
    114 
    115   // skip everything in the connection queue until seeing a BEGIN for the
    116   // current generation
    117   void IgnoreUntilBegin(int connection_id, std::uint32_t generation);
    118 
    119   bool IsOtherSideRecvClosed(int connection_id);
    120 
    121   void ResetQueueStates(layout::socket_forward::QueuePair* queue_pair);
    122 
    123   void MarkQueueDisconnected(int connection_id,
    124                              layout::socket_forward::Queue
    125                                  layout::socket_forward::QueuePair::*direction);
    126 
    127  public:
    128   // Helper class that will send a ConnectionBegin marker when constructed and a
    129   // ConnectionEnd marker when destroyed.
    130   class Sender {
    131    public:
    132     explicit Sender(SocketForwardRegionView* view, int connection_id,
    133                     std::uint32_t generation)
    134         : view_{view, {connection_id, generation}},
    135           connection_id_{connection_id} {
    136       auto packet = Packet::MakeBegin();
    137       packet.set_generation(generation);
    138       view_->Send(connection_id, packet);
    139     }
    140 
    141     Sender(const Sender&) = delete;
    142     Sender& operator=(const Sender&) = delete;
    143 
    144     Sender(Sender&&) = default;
    145     Sender& operator=(Sender&&) = default;
    146     ~Sender() = default;
    147 
    148     // Returns true on success
    149     bool Send(const Packet& packet);
    150     int port() const { return view_->port(connection_id_); }
    151 
    152    private:
    153     bool closed() const;
    154 
    155     struct EndSender {
    156       int connection_id = -1;
    157       std::uint32_t generation{};
    158       void operator()(SocketForwardRegionView* view) const {
    159         if (view) {
    160           CHECK(connection_id >= 0);
    161           auto packet = Packet::MakeEnd();
    162           packet.set_generation(generation);
    163           view->Send(connection_id, packet);
    164           view->MarkSendQueueDisconnected(connection_id);
    165         }
    166       }
    167     };
    168     // Doesn't actually own the View, responsible for sending the End
    169     // indicator and marking the sending side as disconnected.
    170     std::unique_ptr<SocketForwardRegionView, EndSender> view_;
    171     int connection_id_{};
    172   };
    173 
    174   class Receiver {
    175    public:
    176     explicit Receiver(SocketForwardRegionView* view, int connection_id,
    177                       std::uint32_t generation)
    178         : view_{view, {connection_id}},
    179           connection_id_{connection_id},
    180           generation_{generation} {}
    181     Receiver(const Receiver&) = delete;
    182     Receiver& operator=(const Receiver&) = delete;
    183 
    184     Receiver(Receiver&&) = default;
    185     Receiver& operator=(Receiver&&) = default;
    186     ~Receiver() = default;
    187 
    188     void Recv(Packet* packet);
    189     int port() const { return view_->port(connection_id_); }
    190 
    191    private:
    192     struct QueueCloser {
    193       int connection_id = -1;
    194       void operator()(SocketForwardRegionView* view) const {
    195         if (view) {
    196           CHECK(connection_id >= 0);
    197           view->MarkRecvQueueDisconnected(connection_id);
    198         }
    199       }
    200     };
    201 
    202     // Doesn't actually own the View, responsible for marking the receiving
    203     // side as disconnected
    204     std::unique_ptr<SocketForwardRegionView, QueueCloser> view_;
    205     int connection_id_{};
    206     std::uint32_t generation_{};
    207     bool got_begin_ = false;
    208   };
    209 
    210   SocketForwardRegionView() = default;
    211   ~SocketForwardRegionView() = default;
    212   SocketForwardRegionView(const SocketForwardRegionView&) = delete;
    213   SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete;
    214 
    215 #ifdef CUTTLEFISH_HOST
    216   std::pair<Sender, Receiver> OpenConnection(int port);
    217 #else
    218   std::pair<Sender, Receiver> AcceptConnection();
    219 #endif
    220 
    221   int port(int connection_id);
    222   std::uint32_t generation();
    223   void CleanUpPreviousConnections();
    224   void MarkSendQueueDisconnected(int connection_id);
    225   void MarkRecvQueueDisconnected(int connection_id);
    226 
    227  private:
    228 #ifndef CUTTLEFISH_HOST
    229   std::uint32_t last_seq_number_{};
    230 #endif
    231 };
    232 
    233 }  // namespace socket_forward
    234 }  // namespace vsoc
    235