1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/system/raw_channel.h" 6 7 #include <string.h> 8 9 #include <algorithm> 10 11 #include "base/bind.h" 12 #include "base/location.h" 13 #include "base/logging.h" 14 #include "base/message_loop/message_loop.h" 15 #include "base/stl_util.h" 16 #include "mojo/system/message_in_transit.h" 17 #include "mojo/system/transport_data.h" 18 19 namespace mojo { 20 namespace system { 21 22 const size_t kReadSize = 4096; 23 24 // RawChannel::ReadBuffer ------------------------------------------------------ 25 26 RawChannel::ReadBuffer::ReadBuffer() : buffer_(kReadSize), num_valid_bytes_(0) { 27 } 28 29 RawChannel::ReadBuffer::~ReadBuffer() { 30 } 31 32 void RawChannel::ReadBuffer::GetBuffer(char** addr, size_t* size) { 33 DCHECK_GE(buffer_.size(), num_valid_bytes_ + kReadSize); 34 *addr = &buffer_[0] + num_valid_bytes_; 35 *size = kReadSize; 36 } 37 38 // RawChannel::WriteBuffer ----------------------------------------------------- 39 40 RawChannel::WriteBuffer::WriteBuffer(size_t serialized_platform_handle_size) 41 : serialized_platform_handle_size_(serialized_platform_handle_size), 42 platform_handles_offset_(0), 43 data_offset_(0) { 44 } 45 46 RawChannel::WriteBuffer::~WriteBuffer() { 47 STLDeleteElements(&message_queue_); 48 } 49 50 bool RawChannel::WriteBuffer::HavePlatformHandlesToSend() const { 51 if (message_queue_.empty()) 52 return false; 53 54 const TransportData* transport_data = 55 message_queue_.front()->transport_data(); 56 if (!transport_data) 57 return false; 58 59 const embedder::PlatformHandleVector* all_platform_handles = 60 transport_data->platform_handles(); 61 if (!all_platform_handles) { 62 DCHECK_EQ(platform_handles_offset_, 0u); 63 return false; 64 } 65 if (platform_handles_offset_ >= all_platform_handles->size()) { 66 DCHECK_EQ(platform_handles_offset_, 
all_platform_handles->size()); 67 return false; 68 } 69 70 return true; 71 } 72 73 void RawChannel::WriteBuffer::GetPlatformHandlesToSend( 74 size_t* num_platform_handles, 75 embedder::PlatformHandle** platform_handles, 76 void** serialization_data) { 77 DCHECK(HavePlatformHandlesToSend()); 78 79 TransportData* transport_data = message_queue_.front()->transport_data(); 80 embedder::PlatformHandleVector* all_platform_handles = 81 transport_data->platform_handles(); 82 *num_platform_handles = 83 all_platform_handles->size() - platform_handles_offset_; 84 *platform_handles = &(*all_platform_handles)[platform_handles_offset_]; 85 size_t serialization_data_offset = 86 transport_data->platform_handle_table_offset(); 87 DCHECK_GT(serialization_data_offset, 0u); 88 serialization_data_offset += 89 platform_handles_offset_ * serialized_platform_handle_size_; 90 *serialization_data = 91 static_cast<char*>(transport_data->buffer()) + serialization_data_offset; 92 } 93 94 void RawChannel::WriteBuffer::GetBuffers(std::vector<Buffer>* buffers) const { 95 buffers->clear(); 96 97 if (message_queue_.empty()) 98 return; 99 100 MessageInTransit* message = message_queue_.front(); 101 DCHECK_LT(data_offset_, message->total_size()); 102 size_t bytes_to_write = message->total_size() - data_offset_; 103 104 size_t transport_data_buffer_size = 105 message->transport_data() ? message->transport_data()->buffer_size() : 0; 106 107 if (!transport_data_buffer_size) { 108 // Only write from the main buffer. 109 DCHECK_LT(data_offset_, message->main_buffer_size()); 110 DCHECK_LE(bytes_to_write, message->main_buffer_size()); 111 Buffer buffer = { 112 static_cast<const char*>(message->main_buffer()) + data_offset_, 113 bytes_to_write}; 114 buffers->push_back(buffer); 115 return; 116 } 117 118 if (data_offset_ >= message->main_buffer_size()) { 119 // Only write from the transport data buffer. 
120 DCHECK_LT(data_offset_ - message->main_buffer_size(), 121 transport_data_buffer_size); 122 DCHECK_LE(bytes_to_write, transport_data_buffer_size); 123 Buffer buffer = { 124 static_cast<const char*>(message->transport_data()->buffer()) + 125 (data_offset_ - message->main_buffer_size()), 126 bytes_to_write}; 127 buffers->push_back(buffer); 128 return; 129 } 130 131 // TODO(vtl): We could actually send out buffers from multiple messages, with 132 // the "stopping" condition being reaching a message with platform handles 133 // attached. 134 135 // Write from both buffers. 136 DCHECK_EQ( 137 bytes_to_write, 138 message->main_buffer_size() - data_offset_ + transport_data_buffer_size); 139 Buffer buffer1 = { 140 static_cast<const char*>(message->main_buffer()) + data_offset_, 141 message->main_buffer_size() - data_offset_}; 142 buffers->push_back(buffer1); 143 Buffer buffer2 = { 144 static_cast<const char*>(message->transport_data()->buffer()), 145 transport_data_buffer_size}; 146 buffers->push_back(buffer2); 147 } 148 149 // RawChannel ------------------------------------------------------------------ 150 151 RawChannel::RawChannel() 152 : message_loop_for_io_(nullptr), 153 delegate_(nullptr), 154 read_stopped_(false), 155 write_stopped_(false), 156 weak_ptr_factory_(this) { 157 } 158 159 RawChannel::~RawChannel() { 160 DCHECK(!read_buffer_); 161 DCHECK(!write_buffer_); 162 163 // No need to take the |write_lock_| here -- if there are still weak pointers 164 // outstanding, then we're hosed anyway (since we wouldn't be able to 165 // invalidate them cleanly, since we might not be on the I/O thread). 
166 DCHECK(!weak_ptr_factory_.HasWeakPtrs()); 167 } 168 169 bool RawChannel::Init(Delegate* delegate) { 170 DCHECK(delegate); 171 172 DCHECK(!delegate_); 173 delegate_ = delegate; 174 175 CHECK_EQ(base::MessageLoop::current()->type(), base::MessageLoop::TYPE_IO); 176 DCHECK(!message_loop_for_io_); 177 message_loop_for_io_ = 178 static_cast<base::MessageLoopForIO*>(base::MessageLoop::current()); 179 180 // No need to take the lock. No one should be using us yet. 181 DCHECK(!read_buffer_); 182 read_buffer_.reset(new ReadBuffer); 183 DCHECK(!write_buffer_); 184 write_buffer_.reset(new WriteBuffer(GetSerializedPlatformHandleSize())); 185 186 if (!OnInit()) { 187 delegate_ = nullptr; 188 message_loop_for_io_ = nullptr; 189 read_buffer_.reset(); 190 write_buffer_.reset(); 191 return false; 192 } 193 194 IOResult io_result = ScheduleRead(); 195 if (io_result != IO_PENDING) { 196 // This will notify the delegate about the read failure. Although we're on 197 // the I/O thread, don't call it in the nested context. 198 message_loop_for_io_->PostTask(FROM_HERE, 199 base::Bind(&RawChannel::OnReadCompleted, 200 weak_ptr_factory_.GetWeakPtr(), 201 io_result, 202 0)); 203 } 204 205 // ScheduleRead() failure is treated as a read failure (by notifying the 206 // delegate), not as an init failure. 207 return true; 208 } 209 210 void RawChannel::Shutdown() { 211 DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_); 212 213 base::AutoLock locker(write_lock_); 214 215 LOG_IF(WARNING, !write_buffer_->message_queue_.empty()) 216 << "Shutting down RawChannel with write buffer nonempty"; 217 218 // Reset the delegate so that it won't receive further calls. 219 delegate_ = nullptr; 220 read_stopped_ = true; 221 write_stopped_ = true; 222 weak_ptr_factory_.InvalidateWeakPtrs(); 223 224 OnShutdownNoLock(read_buffer_.Pass(), write_buffer_.Pass()); 225 } 226 227 // Reminder: This must be thread-safe. 
// Queues |message| for writing and takes ownership of it in all cases.
// - If the queue was non-empty, the message simply waits behind the head
//   message, which is still being written.
// - If the queue was empty, a write is attempted immediately under
//   |write_lock_|.
// Returns false if writing has already been stopped, or if the immediate write
// failed. In the latter case an |ERROR_WRITE| notification is also posted to
// the I/O message loop.
bool RawChannel::WriteMessage(scoped_ptr<MessageInTransit> message) {
  DCHECK(message);

  base::AutoLock locker(write_lock_);
  if (write_stopped_)
    return false;

  if (!write_buffer_->message_queue_.empty()) {
    // The head message is still being written. |OnWriteCompletedNoLock()|
    // schedules the following write once it finishes.
    EnqueueMessageNoLock(message.Pass());
    return true;
  }

  EnqueueMessageNoLock(message.Pass());
  DCHECK_EQ(write_buffer_->data_offset_, 0u);

  size_t platform_handles_written = 0;
  size_t bytes_written = 0;
  IOResult io_result = WriteNoLock(&platform_handles_written, &bytes_written);
  // NOTE(review): completion of a pending write is presumably reported later
  // via |OnWriteCompleted()| by the platform subclass -- confirm.
  if (io_result == IO_PENDING)
    return true;

  bool result = OnWriteCompletedNoLock(
      io_result, platform_handles_written, bytes_written);
  if (!result) {
    // Even if we're on the I/O thread, don't call |OnError()| in the nested
    // context.
    message_loop_for_io_->PostTask(FROM_HERE,
                                   base::Bind(&RawChannel::CallOnError,
                                              weak_ptr_factory_.GetWeakPtr(),
                                              Delegate::ERROR_WRITE));
  }

  return result;
}

// Reminder: This must be thread-safe.
// Returns true if no messages remain queued for writing. A partially written
// head message still counts as queued.
bool RawChannel::IsWriteBufferEmpty() {
  base::AutoLock locker(write_lock_);
  return write_buffer_->message_queue_.empty();
}

// Runs on the I/O thread when a read finishes with |io_result|. A successful
// read has appended |bytes_read| new bytes after the valid data in
// |read_buffer_|.
// - Every complete message now in the buffer is validated and dispatched.
// - Reading then continues, either by scheduling an asynchronous read or by
//   reading again synchronously, until a read is pending.
// - A read error or an invalid message stops reading and is reported via
//   |CallOnError()|.
void RawChannel::OnReadCompleted(IOResult io_result, size_t bytes_read) {
  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);

  if (read_stopped_) {
    NOTREACHED();
    return;
  }

  // Keep reading data in a loop, and dispatch messages if enough data is
  // received. Exit the loop if any of the following happens:
  //   - one or more messages were dispatched;
  //   - the last read failed, was a partial read or would block;
  //   - |Shutdown()| was called.
  do {
    switch (io_result) {
      case IO_SUCCEEDED:
        break;
      case IO_FAILED_SHUTDOWN:
      case IO_FAILED_BROKEN:
      case IO_FAILED_UNKNOWN:
        read_stopped_ = true;
        CallOnError(ReadIOResultToError(io_result));
        return;
      case IO_PENDING:
        // Only completed reads are reported here; the loop condition below
        // exits as soon as a read is pending.
        NOTREACHED();
        return;
    }

    read_buffer_->num_valid_bytes_ += bytes_read;

    // Dispatch all the messages that we can.
    bool did_dispatch_message = false;
    // Tracks the offset of the first undispatched message in |read_buffer_|.
    // Currently, we copy data to ensure that this is zero at the beginning.
    size_t read_buffer_start = 0;
    size_t remaining_bytes = read_buffer_->num_valid_bytes_;
    size_t message_size;
    // Note that we rely on short-circuit evaluation here:
    //   - |read_buffer_start| may be an invalid index into
    //     |read_buffer_->buffer_| if |remaining_bytes| is zero.
    //   - |message_size| is only valid if |GetNextMessageSize()| returns true.
    // TODO(vtl): Use |message_size| more intelligently (e.g., to request the
    // next read).
    // TODO(vtl): Validate that |message_size| is sane.
    // NOTE(review): a |message_size| of 0 would keep this loop from advancing
    // (|read_buffer_start| and |remaining_bytes| would not change). This
    // relies on |MessageInTransit::View::IsValid()|, which is not visible
    // here, rejecting such headers -- confirm.
    while (remaining_bytes > 0 && MessageInTransit::GetNextMessageSize(
                                      &read_buffer_->buffer_[read_buffer_start],
                                      remaining_bytes,
                                      &message_size) &&
           remaining_bytes >= message_size) {
      MessageInTransit::View message_view(
          message_size, &read_buffer_->buffer_[read_buffer_start]);
      DCHECK_EQ(message_view.total_size(), message_size);

      const char* error_message = nullptr;
      if (!message_view.IsValid(GetSerializedPlatformHandleSize(),
                                &error_message)) {
        DCHECK(error_message);
        LOG(ERROR) << "Received invalid message: " << error_message;
        read_stopped_ = true;
        CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
        return;
      }

      if (message_view.type() == MessageInTransit::kTypeRawChannel) {
        // Control messages for the channel itself are consumed here and are
        // never handed to the delegate.
        if (!OnReadMessageForRawChannel(message_view)) {
          read_stopped_ = true;
          CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
          return;
        }
      } else {
        embedder::ScopedPlatformHandleVectorPtr platform_handles;
        if (message_view.transport_data_buffer()) {
          size_t num_platform_handles;
          const void* platform_handle_table;
          TransportData::GetPlatformHandleTable(
              message_view.transport_data_buffer(),
              &num_platform_handles,
              &platform_handle_table);

          if (num_platform_handles > 0) {
            platform_handles =
                GetReadPlatformHandles(num_platform_handles,
                                       platform_handle_table).Pass();
            if (!platform_handles) {
              LOG(ERROR) << "Invalid number of platform handles received";
              read_stopped_ = true;
              CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
              return;
            }
          }
        }

        // TODO(vtl): In the case that we aren't expecting any platform handles,
        // for the POSIX implementation, we should confirm that none are stored.

        // Dispatch the message.
        DCHECK(delegate_);
        delegate_->OnReadMessage(message_view, platform_handles.Pass());
        // NOTE(review): reading |read_stopped_| here assumes the delegate did
        // not destroy |this| inside |OnReadMessage()| -- confirm against the
        // Delegate contract in raw_channel.h.
        if (read_stopped_) {
          // |Shutdown()| was called in |OnReadMessage()|.
          // TODO(vtl): Add test for this case.
          return;
        }
      }

      // Also set for |kTypeRawChannel| control messages, even though those
      // were not handed to the delegate.
      did_dispatch_message = true;

      // Update our state.
      read_buffer_start += message_size;
      remaining_bytes -= message_size;
    }

    if (read_buffer_start > 0) {
      // Move data back to start.
      read_buffer_->num_valid_bytes_ = remaining_bytes;
      if (read_buffer_->num_valid_bytes_ > 0) {
        memmove(&read_buffer_->buffer_[0],
                &read_buffer_->buffer_[read_buffer_start],
                remaining_bytes);
      }
      read_buffer_start = 0;
    }

    // Keep at least |kReadSize| bytes free after the valid data, which is what
    // |ReadBuffer::GetBuffer()| DCHECKs. Grow by doubling, starting from at
    // least |kReadSize|.
    if (read_buffer_->buffer_.size() - read_buffer_->num_valid_bytes_ <
        kReadSize) {
      // Use power-of-2 buffer sizes.
      // TODO(vtl): Make sure the buffer doesn't get too large (and enforce the
      // maximum message size to whatever extent necessary).
      // TODO(vtl): We may often be able to peek at the header and get the real
      // required extra space (which may be much bigger than |kReadSize|).
      size_t new_size = std::max(read_buffer_->buffer_.size(), kReadSize);
      while (new_size < read_buffer_->num_valid_bytes_ + kReadSize)
        new_size *= 2;

      // TODO(vtl): It's suboptimal to zero out the fresh memory.
      read_buffer_->buffer_.resize(new_size, 0);
    }

    // (1) If we dispatched any messages, stop reading for now (and let the
    // message loop do its thing for another round).
    // TODO(vtl): Is this the behavior we want? (Alternatives: i. Dispatch only
    // a single message. Risks: slower, more complex if we want to avoid lots of
    // copying. ii. Keep reading until there's no more data and dispatch all the
    // messages we can. Risks: starvation of other users of the message loop.)
    // (2) If we didn't max out |kReadSize|, stop reading for now.
    bool schedule_for_later = did_dispatch_message || bytes_read < kReadSize;
    // Reset the count. If |ScheduleRead()| returns a non-pending result, the
    // next iteration then does not re-add the previous byte count.
    bytes_read = 0;
    io_result = schedule_for_later ? ScheduleRead() : Read(&bytes_read);
  } while (io_result != IO_PENDING);
}

// Runs on the I/O thread when a write finishes with a non-pending |io_result|.
// The write state is updated under |write_lock_|. On failure, the delegate is
// notified with |ERROR_WRITE| after the lock has been released.
void RawChannel::OnWriteCompleted(IOResult io_result,
                                  size_t platform_handles_written,
                                  size_t bytes_written) {
  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
  DCHECK_NE(io_result, IO_PENDING);

  bool did_fail = false;
  {
    base::AutoLock locker(write_lock_);
    DCHECK_EQ(write_stopped_, write_buffer_->message_queue_.empty());

    if (write_stopped_) {
      NOTREACHED();
      return;
    }

    did_fail = !OnWriteCompletedNoLock(
        io_result, platform_handles_written, bytes_written);
  }

  if (did_fail)
    CallOnError(Delegate::ERROR_WRITE);
}

// Appends |message| to the write queue, which takes ownership. The message is
// deleted once fully written or when writing stops. Caller must hold
// |write_lock_|.
void RawChannel::EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message) {
  write_lock_.AssertAcquired();
  write_buffer_->message_queue_.push_back(message.release());
}

// Handles a |MessageInTransit::kTypeRawChannel| message. This implementation
// logs the message and returns false, which |OnReadCompleted()| treats as a
// bad message. NOTE(review): implementations with their own control messages
// presumably override this -- confirm in raw_channel.h.
bool RawChannel::OnReadMessageForRawChannel(
    const MessageInTransit::View& message_view) {
  // No non-implementation specific |RawChannel| control messages.
  LOG(ERROR) << "Invalid control message (subtype " << message_view.subtype()
             << ")";
  return false;
}

// static
// Maps a failed read's |IOResult| to the delegate error to report. Only the
// failure values are expected here. Anything else hits NOTREACHED() and falls
// back to |ERROR_READ_UNKNOWN|.
RawChannel::Delegate::Error RawChannel::ReadIOResultToError(
    IOResult io_result) {
  switch (io_result) {
    case IO_FAILED_SHUTDOWN:
      return Delegate::ERROR_READ_SHUTDOWN;
    case IO_FAILED_BROKEN:
      return Delegate::ERROR_READ_BROKEN;
    case IO_FAILED_UNKNOWN:
      return Delegate::ERROR_READ_UNKNOWN;
    case IO_SUCCEEDED:
    case IO_PENDING:
      NOTREACHED();
      break;
  }
  return Delegate::ERROR_READ_UNKNOWN;
}

// Forwards |error| to the delegate, if one is still set. After |Shutdown()|
// has cleared |delegate_|, errors are silently dropped. Must run on the I/O
// thread.
void RawChannel::CallOnError(Delegate::Error error) {
  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
  // TODO(vtl): Add a "write_lock_.AssertNotAcquired()"?
  if (delegate_)
    delegate_->OnError(error);
}

// Accounts for a finished write of |platform_handles_written| handles and
// |bytes_written| bytes of the head message. Requires |write_lock_|.
// - On success: retires the head message once it is fully written, then
//   schedules the next write if anything remains. Returns true while writing
//   can continue.
// - On any failure: stops writing, deletes every queued message, resets the
//   offsets, and returns false. The caller is responsible for reporting the
//   error.
bool RawChannel::OnWriteCompletedNoLock(IOResult io_result,
                                        size_t platform_handles_written,
                                        size_t bytes_written) {
  write_lock_.AssertAcquired();

  DCHECK(!write_stopped_);
  DCHECK(!write_buffer_->message_queue_.empty());

  if (io_result == IO_SUCCEEDED) {
    write_buffer_->platform_handles_offset_ += platform_handles_written;
    write_buffer_->data_offset_ += bytes_written;

    MessageInTransit* message = write_buffer_->message_queue_.front();
    if (write_buffer_->data_offset_ >= message->total_size()) {
      // Complete write.
      CHECK_EQ(write_buffer_->data_offset_, message->total_size());
      write_buffer_->message_queue_.pop_front();
      delete message;
      write_buffer_->platform_handles_offset_ = 0;
      write_buffer_->data_offset_ = 0;

      if (write_buffer_->message_queue_.empty())
        return true;
    }

    // Schedule the next write. This either continues the partially written
    // head message or starts the next queued one.
    io_result = ScheduleWriteNoLock();
    if (io_result == IO_PENDING)
      return true;
    // An unexpected synchronous IO_SUCCEEDED is DCHECKed. In non-DCHECK builds
    // it falls through to the failure path below.
    DCHECK_NE(io_result, IO_SUCCEEDED);
  }

  write_stopped_ = true;
  STLDeleteElements(&write_buffer_->message_queue_);
  write_buffer_->platform_handles_offset_ = 0;
  write_buffer_->data_offset_ = 0;
  return false;
}

}  // namespace system
}  // namespace mojo