1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include <utility> 19 #include <vector> 20 #include <memory> 21 22 #include "common/vsoc/lib/typed_region_view.h" 23 #include "common/vsoc/shm/socket_forward_layout.h" 24 25 namespace vsoc { 26 namespace socket_forward { 27 28 struct Header { 29 std::uint32_t payload_length; 30 std::uint32_t generation; 31 enum MessageType : std::uint32_t { 32 DATA = 0, 33 BEGIN, 34 END, 35 }; 36 MessageType message_type; 37 }; 38 39 constexpr std::size_t kMaxPayloadSize = 40 layout::socket_forward::kMaxPacketSize - sizeof(Header); 41 42 struct Packet { 43 private: 44 Header header_; 45 using Payload = char[kMaxPayloadSize]; 46 Payload payload_data_; 47 48 static Packet MakePacket(Header::MessageType type) { 49 Packet packet{}; 50 packet.header_.message_type = type; 51 return packet; 52 } 53 54 public: 55 static Packet MakeBegin() { return MakePacket(Header::BEGIN); } 56 57 static Packet MakeEnd() { return MakePacket(Header::END); } 58 59 // NOTE payload and payload_length must still be set. 60 static Packet MakeData() { return MakePacket(Header::DATA); } 61 62 bool empty() const { return IsData() && header_.payload_length == 0; } 63 64 void set_payload_length(std::uint32_t length) { 65 CHECK_LE(length, sizeof payload_data_); 66 header_.message_type = Header::DATA; 67 header_.payload_length = length; 68 } 69 70 std::uint32_t generation() const { return header_.generation; } 71 72 void set_generation(std::uint32_t generation) { 73 header_.generation = generation; 74 } 75 76 Payload& payload() { return payload_data_; } 77 78 const Payload& payload() const { return payload_data_; } 79 80 std::uint32_t payload_length() const { return header_.payload_length; } 81 82 bool IsBegin() const { return header_.message_type == Header::BEGIN; } 83 84 bool IsEnd() const { return header_.message_type == Header::END; } 85 86 bool IsData() const { return header_.message_type == Header::DATA; } 87 88 char* raw_data() { return reinterpret_cast<char*>(this); } 89 90 const char* raw_data() const { return reinterpret_cast<const char*>(this); } 91 92 size_t raw_data_length() const { return payload_length() + sizeof header_; } 93 }; 94 95 static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, ""); 96 static_assert(std::is_pod<Packet>{}, ""); 97 98 // Data sent will start with a uint32_t indicating the number of bytes being 99 // sent, followed be the data itself 100 class SocketForwardRegionView 101 : public TypedRegionView<SocketForwardRegionView, 102 layout::socket_forward::SocketForwardLayout> { 103 private: 104 #ifdef CUTTLEFISH_HOST 105 int AcquireConnectionID(int port); 106 #else 107 int GetWaitingConnectionID(); 108 #endif 109 110 // Returns an empty data packet if the other side is closed. 111 void Recv(int connection_id, Packet* packet); 112 // Returns true on success 113 bool Send(int connection_id, const Packet& packet); 114 115 // skip everything in the connection queue until seeing a BEGIN for the 116 // current generation 117 void IgnoreUntilBegin(int connection_id, std::uint32_t generation); 118 119 bool IsOtherSideRecvClosed(int connection_id); 120 121 void ResetQueueStates(layout::socket_forward::QueuePair* queue_pair); 122 123 void MarkQueueDisconnected(int connection_id, 124 layout::socket_forward::Queue 125 layout::socket_forward::QueuePair::*direction); 126 127 public: 128 // Helper class that will send a ConnectionBegin marker when constructed and a 129 // ConnectionEnd marker when destroyed. 130 class Sender { 131 public: 132 explicit Sender(SocketForwardRegionView* view, int connection_id, 133 std::uint32_t generation) 134 : view_{view, {connection_id, generation}}, 135 connection_id_{connection_id} { 136 auto packet = Packet::MakeBegin(); 137 packet.set_generation(generation); 138 view_->Send(connection_id, packet); 139 } 140 141 Sender(const Sender&) = delete; 142 Sender& operator=(const Sender&) = delete; 143 144 Sender(Sender&&) = default; 145 Sender& operator=(Sender&&) = default; 146 ~Sender() = default; 147 148 // Returns true on success 149 bool Send(const Packet& packet); 150 int port() const { return view_->port(connection_id_); } 151 152 private: 153 bool closed() const; 154 155 struct EndSender { 156 int connection_id = -1; 157 std::uint32_t generation{}; 158 void operator()(SocketForwardRegionView* view) const { 159 if (view) { 160 CHECK(connection_id >= 0); 161 auto packet = Packet::MakeEnd(); 162 packet.set_generation(generation); 163 view->Send(connection_id, packet); 164 view->MarkSendQueueDisconnected(connection_id); 165 } 166 } 167 }; 168 // Doesn't actually own the View, responsible for sending the End 169 // indicator and marking the sending side as disconnected. 170 std::unique_ptr<SocketForwardRegionView, EndSender> view_; 171 int connection_id_{}; 172 }; 173 174 class Receiver { 175 public: 176 explicit Receiver(SocketForwardRegionView* view, int connection_id, 177 std::uint32_t generation) 178 : view_{view, {connection_id}}, 179 connection_id_{connection_id}, 180 generation_{generation} {} 181 Receiver(const Receiver&) = delete; 182 Receiver& operator=(const Receiver&) = delete; 183 184 Receiver(Receiver&&) = default; 185 Receiver& operator=(Receiver&&) = default; 186 ~Receiver() = default; 187 188 void Recv(Packet* packet); 189 int port() const { return view_->port(connection_id_); } 190 191 private: 192 struct QueueCloser { 193 int connection_id = -1; 194 void operator()(SocketForwardRegionView* view) const { 195 if (view) { 196 CHECK(connection_id >= 0); 197 view->MarkRecvQueueDisconnected(connection_id); 198 } 199 } 200 }; 201 202 // Doesn't actually own the View, responsible for marking the receiving 203 // side as disconnected 204 std::unique_ptr<SocketForwardRegionView, QueueCloser> view_; 205 int connection_id_{}; 206 std::uint32_t generation_{}; 207 bool got_begin_ = false; 208 }; 209 210 SocketForwardRegionView() = default; 211 ~SocketForwardRegionView() = default; 212 SocketForwardRegionView(const SocketForwardRegionView&) = delete; 213 SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete; 214 215 #ifdef CUTTLEFISH_HOST 216 std::pair<Sender, Receiver> OpenConnection(int port); 217 #else 218 std::pair<Sender, Receiver> AcceptConnection(); 219 #endif 220 221 int port(int connection_id); 222 std::uint32_t generation(); 223 void CleanUpPreviousConnections(); 224 void MarkSendQueueDisconnected(int connection_id); 225 void MarkRecvQueueDisconnected(int connection_id); 226 227 private: 228 #ifndef CUTTLEFISH_HOST 229 std::uint32_t last_seq_number_{}; 230 #endif 231 }; 232 233 } // namespace socket_forward 234 } // namespace vsoc 235