Home | History | Annotate | Download | only in streams
      1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <brillo/streams/stream_utils.h>
      6 
      7 #include <limits>
      8 
      9 #include <base/bind.h>
     10 #include <brillo/message_loops/message_loop.h>
     11 #include <brillo/streams/stream_errors.h>
     12 
     13 namespace brillo {
     14 namespace stream_utils {
     15 
     16 namespace {
     17 
     18 // Status of asynchronous CopyData operation.
     19 struct CopyDataState {
     20   brillo::StreamPtr in_stream;
     21   brillo::StreamPtr out_stream;
     22   std::vector<uint8_t> buffer;
     23   uint64_t remaining_to_copy;
     24   uint64_t size_copied;
     25   CopyDataSuccessCallback success_callback;
     26   CopyDataErrorCallback error_callback;
     27 };
     28 
     29 // Async CopyData I/O error callback.
     30 void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
     31                      const brillo::Error* error) {
     32   state->error_callback.Run(std::move(state->in_stream),
     33                             std::move(state->out_stream), error);
     34 }
     35 
     36 // Forward declaration.
     37 void PerformRead(const std::shared_ptr<CopyDataState>& state);
     38 
     39 // Callback from read operation for CopyData. Writes the read data to the output
     40 // stream and invokes PerformRead when done to restart the copy cycle.
     41 void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
     42   if (size == 0) {
     43     state->success_callback.Run(std::move(state->in_stream),
     44                                 std::move(state->out_stream),
     45                                 state->size_copied);
     46     return;
     47   }
     48   state->size_copied += size;
     49   CHECK_GE(state->remaining_to_copy, size);
     50   state->remaining_to_copy -= size;
     51 
     52   brillo::ErrorPtr error;
     53   bool success = state->out_stream->WriteAllAsync(
     54       state->buffer.data(), size, base::Bind(&PerformRead, state),
     55       base::Bind(&OnCopyDataError, state), &error);
     56 
     57   if (!success)
     58     OnCopyDataError(state, error.get());
     59 }
     60 
     61 // Performs the read part of asynchronous CopyData operation. Reads the data
     62 // from input stream and invokes PerformWrite when done to write the data to
     63 // the output stream.
     64 void PerformRead(const std::shared_ptr<CopyDataState>& state) {
     65   brillo::ErrorPtr error;
     66   const uint64_t buffer_size = state->buffer.size();
     67   // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
     68   // also not overflow size_t, so the static_cast below is safe.
     69   size_t size_to_read =
     70       static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
     71   if (size_to_read == 0)
     72     return PerformWrite(state, 0);  // Nothing more to read. Finish operation.
     73   bool success = state->in_stream->ReadAsync(
     74       state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
     75       base::Bind(OnCopyDataError, state), &error);
     76 
     77   if (!success)
     78     OnCopyDataError(state, error.get());
     79 }
     80 
     81 }  // anonymous namespace
     82 
     83 bool ErrorStreamClosed(const tracked_objects::Location& location,
     84                        ErrorPtr* error) {
     85   Error::AddTo(error,
     86                location,
     87                errors::stream::kDomain,
     88                errors::stream::kStreamClosed,
     89                "Stream is closed");
     90   return false;
     91 }
     92 
     93 bool ErrorOperationNotSupported(const tracked_objects::Location& location,
     94                                 ErrorPtr* error) {
     95   Error::AddTo(error,
     96                location,
     97                errors::stream::kDomain,
     98                errors::stream::kOperationNotSupported,
     99                "Stream operation not supported");
    100   return false;
    101 }
    102 
    103 bool ErrorReadPastEndOfStream(const tracked_objects::Location& location,
    104                               ErrorPtr* error) {
    105   Error::AddTo(error,
    106                location,
    107                errors::stream::kDomain,
    108                errors::stream::kPartialData,
    109                "Reading past the end of stream");
    110   return false;
    111 }
    112 
    113 bool ErrorOperationTimeout(const tracked_objects::Location& location,
    114                            ErrorPtr* error) {
    115   Error::AddTo(error,
    116                location,
    117                errors::stream::kDomain,
    118                errors::stream::kTimeout,
    119                "Operation timed out");
    120   return false;
    121 }
    122 
    123 bool CheckInt64Overflow(const tracked_objects::Location& location,
    124                         uint64_t position,
    125                         int64_t offset,
    126                         ErrorPtr* error) {
    127   if (offset < 0) {
    128     // Subtracting the offset. Make sure we do not underflow.
    129     uint64_t unsigned_offset = static_cast<uint64_t>(-offset);
    130     if (position >= unsigned_offset)
    131       return true;
    132   } else {
    133     // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
    134     if (position <= std::numeric_limits<uint64_t>::max() - offset) {
    135       // We definitely will not overflow the unsigned 64 bit integer.
    136       // Now check that we end up within the limits of signed 64 bit integer.
    137       uint64_t new_position = position + offset;
    138       uint64_t max = std::numeric_limits<int64_t>::max();
    139       if (new_position <= max)
    140         return true;
    141     }
    142   }
    143   Error::AddTo(error,
    144                location,
    145                errors::stream::kDomain,
    146                errors::stream::kInvalidParameter,
    147                "The stream offset value is out of range");
    148   return false;
    149 }
    150 
    151 bool CalculateStreamPosition(const tracked_objects::Location& location,
    152                              int64_t offset,
    153                              Stream::Whence whence,
    154                              uint64_t current_position,
    155                              uint64_t stream_size,
    156                              uint64_t* new_position,
    157                              ErrorPtr* error) {
    158   uint64_t pos = 0;
    159   switch (whence) {
    160     case Stream::Whence::FROM_BEGIN:
    161       pos = 0;
    162       break;
    163 
    164     case Stream::Whence::FROM_CURRENT:
    165       pos = current_position;
    166       break;
    167 
    168     case Stream::Whence::FROM_END:
    169       pos = stream_size;
    170       break;
    171 
    172     default:
    173       Error::AddTo(error,
    174                    location,
    175                    errors::stream::kDomain,
    176                    errors::stream::kInvalidParameter,
    177                    "Invalid stream position whence");
    178       return false;
    179   }
    180 
    181   if (!CheckInt64Overflow(location, pos, offset, error))
    182     return false;
    183 
    184   *new_position = static_cast<uint64_t>(pos + offset);
    185   return true;
    186 }
    187 
    188 void CopyData(StreamPtr in_stream,
    189               StreamPtr out_stream,
    190               const CopyDataSuccessCallback& success_callback,
    191               const CopyDataErrorCallback& error_callback) {
    192   CopyData(std::move(in_stream), std::move(out_stream),
    193            std::numeric_limits<uint64_t>::max(), 4096, success_callback,
    194            error_callback);
    195 }
    196 
    197 void CopyData(StreamPtr in_stream,
    198               StreamPtr out_stream,
    199               uint64_t max_size_to_copy,
    200               size_t buffer_size,
    201               const CopyDataSuccessCallback& success_callback,
    202               const CopyDataErrorCallback& error_callback) {
    203   auto state = std::make_shared<CopyDataState>();
    204   state->in_stream = std::move(in_stream);
    205   state->out_stream = std::move(out_stream);
    206   state->buffer.resize(buffer_size);
    207   state->remaining_to_copy = max_size_to_copy;
    208   state->size_copied = 0;
    209   state->success_callback = success_callback;
    210   state->error_callback = error_callback;
    211   brillo::MessageLoop::current()->PostTask(FROM_HERE,
    212                                              base::Bind(&PerformRead, state));
    213 }
    214 
    215 }  // namespace stream_utils
    216 }  // namespace brillo
    217