1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <cassert> 18 19 #include "common/vsoc/lib/circqueue_impl.h" 20 #include "common/vsoc/lib/lock_guard.h" 21 #include "common/vsoc/lib/socket_forward_region_view.h" 22 #include "common/vsoc/shm/lock.h" 23 #include "common/vsoc/shm/socket_forward_layout.h" 24 25 using vsoc::layout::socket_forward::Queue; 26 using vsoc::layout::socket_forward::QueuePair; 27 using vsoc::layout::socket_forward::QueueState; 28 // store the read and write direction as variables to keep the ifdefs and macros 29 // in later code to a minimum 30 constexpr auto ReadDirection = &QueuePair:: 31 #ifdef CUTTLEFISH_HOST 32 guest_to_host; 33 #else 34 host_to_guest; 35 #endif 36 37 constexpr auto WriteDirection = &QueuePair:: 38 #ifdef CUTTLEFISH_HOST 39 host_to_guest; 40 #else 41 guest_to_host; 42 #endif 43 44 constexpr auto kOtherSideClosed = QueueState:: 45 #ifdef CUTTLEFISH_HOST 46 GUEST_CLOSED; 47 #else 48 HOST_CLOSED; 49 #endif 50 51 constexpr auto kThisSideClosed = QueueState:: 52 #ifdef CUTTLEFISH_HOST 53 HOST_CLOSED; 54 #else 55 GUEST_CLOSED; 56 #endif 57 58 using vsoc::socket_forward::SocketForwardRegionView; 59 60 void SocketForwardRegionView::Recv(int connection_id, Packet* packet) { 61 CHECK(packet != nullptr); 62 do { 63 (data()->queues_[connection_id].*ReadDirection) 64 .queue.Read(this, reinterpret_cast<char*>(packet), sizeof *packet); 65 } while (packet->IsBegin()); 66 // TODO(haining) check packet generation number 67 CHECK(!packet->empty()) << "zero-size data message received"; 68 CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size"; 69 } 70 71 bool SocketForwardRegionView::Send(int connection_id, const Packet& packet) { 72 CHECK(!packet.empty()); 73 CHECK_LE(packet.payload_length(), kMaxPayloadSize); 74 75 // NOTE this is check-then-act but I think that it's okay. Worst case is that 76 // we send one-too-many packets. 77 auto& queue_pair = data()->queues_[connection_id]; 78 { 79 auto guard = make_lock_guard(&queue_pair.queue_state_lock_); 80 if ((queue_pair.*WriteDirection).queue_state_ == kOtherSideClosed) { 81 LOG(INFO) << "connection closed, not sending\n"; 82 return false; 83 } 84 CHECK((queue_pair.*WriteDirection).queue_state_ != QueueState::INACTIVE); 85 } 86 // TODO(haining) set packet generation number 87 (data()->queues_[connection_id].*WriteDirection) 88 .queue.Write(this, packet.raw_data(), packet.raw_data_length()); 89 return true; 90 } 91 92 void SocketForwardRegionView::IgnoreUntilBegin(int connection_id, 93 std::uint32_t generation) { 94 Packet packet{}; 95 do { 96 (data()->queues_[connection_id].*ReadDirection) 97 .queue.Read(this, reinterpret_cast<char*>(&packet), sizeof packet); 98 } while (!packet.IsBegin() || packet.generation() < generation); 99 } 100 101 bool SocketForwardRegionView::IsOtherSideRecvClosed(int connection_id) { 102 auto& queue_pair = data()->queues_[connection_id]; 103 auto guard = make_lock_guard(&queue_pair.queue_state_lock_); 104 auto& queue = queue_pair.*WriteDirection; 105 return queue.queue_state_ == kOtherSideClosed || 106 queue.queue_state_ == QueueState::INACTIVE; 107 } 108 109 void SocketForwardRegionView::ResetQueueStates(QueuePair* queue_pair) { 110 using vsoc::layout::socket_forward::Queue; 111 auto guard = make_lock_guard(&queue_pair->queue_state_lock_); 112 Queue* queues[] = {&queue_pair->host_to_guest, &queue_pair->guest_to_host}; 113 for (auto* queue : queues) { 114 auto& state = queue->queue_state_; 115 switch (state) { 116 case QueueState::HOST_CONNECTED: 117 case kOtherSideClosed: 118 LOG(DEBUG) 119 << "host_connected or other side is closed, marking inactive"; 120 state = QueueState::INACTIVE; 121 break; 122 123 case QueueState::BOTH_CONNECTED: 124 LOG(DEBUG) << "both_connected, marking this side closed"; 125 state = kThisSideClosed; 126 break; 127 128 case kThisSideClosed: 129 [[fallthrough]]; 130 case QueueState::INACTIVE: 131 LOG(DEBUG) << "inactive or this side closed, not changing state"; 132 break; 133 } 134 } 135 } 136 137 void SocketForwardRegionView::CleanUpPreviousConnections() { 138 data()->Recover(); 139 int connection_id = 0; 140 auto current_generation = generation(); 141 auto begin_packet = Packet::MakeBegin(); 142 begin_packet.set_generation(current_generation); 143 auto end_packet = Packet::MakeEnd(); 144 end_packet.set_generation(current_generation); 145 for (auto&& queue_pair : data()->queues_) { 146 QueueState state{}; 147 { 148 auto guard = make_lock_guard(&queue_pair.queue_state_lock_); 149 state = (queue_pair.*WriteDirection).queue_state_; 150 #ifndef CUTTLEFISH_HOST 151 if (state == QueueState::HOST_CONNECTED) { 152 state = (queue_pair.*WriteDirection).queue_state_ = 153 (queue_pair.*ReadDirection).queue_state_ = 154 QueueState::BOTH_CONNECTED; 155 } 156 #endif 157 } 158 159 if (state == QueueState::BOTH_CONNECTED 160 #ifdef CUTTLEFISH_HOST 161 || state == QueueState::HOST_CONNECTED 162 #endif 163 ) { 164 LOG(INFO) << "found connected write queue state, sending begin and end"; 165 Send(connection_id, begin_packet); 166 Send(connection_id, end_packet); 167 } 168 ResetQueueStates(&queue_pair); 169 ++connection_id; 170 } 171 ++data()->generation_num; 172 } 173 174 void SocketForwardRegionView::MarkQueueDisconnected( 175 int connection_id, Queue QueuePair::*direction) { 176 auto& queue_pair = data()->queues_[connection_id]; 177 auto& queue = queue_pair.*direction; 178 179 #ifdef CUTTLEFISH_HOST 180 // if the host has connected but the guest hasn't seen it yet, wait for the 181 // guest to connect so the protocol can follow the normal state transition. 182 while (true) { 183 auto guard = make_lock_guard(&queue_pair.queue_state_lock_); 184 if (queue.queue_state_ != QueueState::HOST_CONNECTED) { 185 break; 186 } 187 LOG(WARNING) << "closing queue in HOST_CONNECTED state. waiting"; 188 sleep(1); 189 } 190 #endif 191 192 auto guard = make_lock_guard(&queue_pair.queue_state_lock_); 193 194 queue.queue_state_ = queue.queue_state_ == kOtherSideClosed 195 ? QueueState::INACTIVE 196 : kThisSideClosed; 197 } 198 199 void SocketForwardRegionView::MarkSendQueueDisconnected(int connection_id) { 200 MarkQueueDisconnected(connection_id, WriteDirection); 201 } 202 203 void SocketForwardRegionView::MarkRecvQueueDisconnected(int connection_id) { 204 MarkQueueDisconnected(connection_id, ReadDirection); 205 } 206 207 int SocketForwardRegionView::port(int connection_id) { 208 return data()->queues_[connection_id].port_; 209 } 210 211 std::uint32_t SocketForwardRegionView::generation() { 212 return data()->generation_num; 213 } 214 215 #ifdef CUTTLEFISH_HOST 216 int SocketForwardRegionView::AcquireConnectionID(int port) { 217 while (true) { 218 int id = 0; 219 for (auto&& queue_pair : data()->queues_) { 220 LOG(DEBUG) << "locking and checking queue at index " << id; 221 auto guard = make_lock_guard(&queue_pair.queue_state_lock_); 222 if (queue_pair.host_to_guest.queue_state_ == QueueState::INACTIVE && 223 queue_pair.guest_to_host.queue_state_ == QueueState::INACTIVE) { 224 queue_pair.port_ = port; 225 queue_pair.host_to_guest.queue_state_ = QueueState::HOST_CONNECTED; 226 queue_pair.guest_to_host.queue_state_ = QueueState::HOST_CONNECTED; 227 LOG(DEBUG) << "acquired queue " << id 228 << ". current seq_num: " << data()->seq_num; 229 ++data()->seq_num; 230 SendSignal(layout::Sides::Peer, &data()->seq_num); 231 return id; 232 } 233 ++id; 234 } 235 LOG(ERROR) << "no remaining shm queues for connection, sleeping."; 236 sleep(10); 237 } 238 } 239 240 std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver> 241 SocketForwardRegionView::OpenConnection(int port) { 242 int connection_id = AcquireConnectionID(port); 243 LOG(INFO) << "Acquired connection with id " << connection_id; 244 auto current_generation = generation(); 245 return {Sender{this, connection_id, current_generation}, 246 Receiver{this, connection_id, current_generation}}; 247 } 248 #else 249 int SocketForwardRegionView::GetWaitingConnectionID() { 250 while (data()->seq_num == last_seq_number_) { 251 WaitForSignal(&data()->seq_num, last_seq_number_); 252 } 253 ++last_seq_number_; 254 int id = 0; 255 for (auto&& queue_pair : data()->queues_) { 256 LOG(DEBUG) << "locking and checking queue at index " << id; 257 auto guard = make_lock_guard(&queue_pair.queue_state_lock_); 258 if (queue_pair.host_to_guest.queue_state_ == QueueState::HOST_CONNECTED) { 259 CHECK(queue_pair.guest_to_host.queue_state_ == 260 QueueState::HOST_CONNECTED); 261 LOG(DEBUG) << "found waiting connection at index " << id; 262 queue_pair.host_to_guest.queue_state_ = QueueState::BOTH_CONNECTED; 263 queue_pair.guest_to_host.queue_state_ = QueueState::BOTH_CONNECTED; 264 return id; 265 } 266 ++id; 267 } 268 return -1; 269 } 270 271 std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver> 272 SocketForwardRegionView::AcceptConnection() { 273 int connection_id = -1; 274 while (connection_id < 0) { 275 connection_id = GetWaitingConnectionID(); 276 } 277 LOG(INFO) << "Accepted connection with id " << connection_id; 278 279 auto current_generation = generation(); 280 return {Sender{this, connection_id, current_generation}, 281 Receiver{this, connection_id, current_generation}}; 282 } 283 #endif 284 285 // --- Connection ---- // 286 void SocketForwardRegionView::Receiver::Recv(Packet* packet) { 287 if (!got_begin_) { 288 view_->IgnoreUntilBegin(connection_id_, generation_); 289 got_begin_ = true; 290 } 291 return view_->Recv(connection_id_, packet); 292 } 293 294 bool SocketForwardRegionView::Sender::closed() const { 295 return view_->IsOtherSideRecvClosed(connection_id_); 296 } 297 298 bool SocketForwardRegionView::Sender::Send(const Packet& packet) { 299 return view_->Send(connection_id_, packet); 300 } 301