1 // Copyright 2015 The Chromium OS Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <brillo/streams/stream_utils.h> 6 7 #include <limits> 8 9 #include <base/bind.h> 10 #include <brillo/message_loops/message_loop.h> 11 #include <brillo/streams/stream_errors.h> 12 13 namespace brillo { 14 namespace stream_utils { 15 16 namespace { 17 18 // Status of asynchronous CopyData operation. 19 struct CopyDataState { 20 brillo::StreamPtr in_stream; 21 brillo::StreamPtr out_stream; 22 std::vector<uint8_t> buffer; 23 uint64_t remaining_to_copy; 24 uint64_t size_copied; 25 CopyDataSuccessCallback success_callback; 26 CopyDataErrorCallback error_callback; 27 }; 28 29 // Async CopyData I/O error callback. 30 void OnCopyDataError(const std::shared_ptr<CopyDataState>& state, 31 const brillo::Error* error) { 32 state->error_callback.Run(std::move(state->in_stream), 33 std::move(state->out_stream), error); 34 } 35 36 // Forward declaration. 37 void PerformRead(const std::shared_ptr<CopyDataState>& state); 38 39 // Callback from read operation for CopyData. Writes the read data to the output 40 // stream and invokes PerformRead when done to restart the copy cycle. 41 void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) { 42 if (size == 0) { 43 state->success_callback.Run(std::move(state->in_stream), 44 std::move(state->out_stream), 45 state->size_copied); 46 return; 47 } 48 state->size_copied += size; 49 CHECK_GE(state->remaining_to_copy, size); 50 state->remaining_to_copy -= size; 51 52 brillo::ErrorPtr error; 53 bool success = state->out_stream->WriteAllAsync( 54 state->buffer.data(), size, base::Bind(&PerformRead, state), 55 base::Bind(&OnCopyDataError, state), &error); 56 57 if (!success) 58 OnCopyDataError(state, error.get()); 59 } 60 61 // Performs the read part of asynchronous CopyData operation. Reads the data 62 // from input stream and invokes PerformWrite when done to write the data to 63 // the output stream. 64 void PerformRead(const std::shared_ptr<CopyDataState>& state) { 65 brillo::ErrorPtr error; 66 const uint64_t buffer_size = state->buffer.size(); 67 // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will 68 // also not overflow size_t, so the static_cast below is safe. 69 size_t size_to_read = 70 static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy)); 71 if (size_to_read == 0) 72 return PerformWrite(state, 0); // Nothing more to read. Finish operation. 73 bool success = state->in_stream->ReadAsync( 74 state->buffer.data(), size_to_read, base::Bind(PerformWrite, state), 75 base::Bind(OnCopyDataError, state), &error); 76 77 if (!success) 78 OnCopyDataError(state, error.get()); 79 } 80 81 } // anonymous namespace 82 83 bool ErrorStreamClosed(const base::Location& location, 84 ErrorPtr* error) { 85 Error::AddTo(error, 86 location, 87 errors::stream::kDomain, 88 errors::stream::kStreamClosed, 89 "Stream is closed"); 90 return false; 91 } 92 93 bool ErrorOperationNotSupported(const base::Location& location, 94 ErrorPtr* error) { 95 Error::AddTo(error, 96 location, 97 errors::stream::kDomain, 98 errors::stream::kOperationNotSupported, 99 "Stream operation not supported"); 100 return false; 101 } 102 103 bool ErrorReadPastEndOfStream(const base::Location& location, 104 ErrorPtr* error) { 105 Error::AddTo(error, 106 location, 107 errors::stream::kDomain, 108 errors::stream::kPartialData, 109 "Reading past the end of stream"); 110 return false; 111 } 112 113 bool ErrorOperationTimeout(const base::Location& location, 114 ErrorPtr* error) { 115 Error::AddTo(error, 116 location, 117 errors::stream::kDomain, 118 errors::stream::kTimeout, 119 "Operation timed out"); 120 return false; 121 } 122 123 bool CheckInt64Overflow(const base::Location& location, 124 uint64_t position, 125 int64_t offset, 126 ErrorPtr* error) { 127 if (offset < 0) { 128 // Subtracting the offset. Make sure we do not underflow. 129 uint64_t unsigned_offset = static_cast<uint64_t>(-offset); 130 if (position >= unsigned_offset) 131 return true; 132 } else { 133 // Adding the offset. Make sure we do not overflow unsigned 64 bits first. 134 if (position <= std::numeric_limits<uint64_t>::max() - offset) { 135 // We definitely will not overflow the unsigned 64 bit integer. 136 // Now check that we end up within the limits of signed 64 bit integer. 137 uint64_t new_position = position + offset; 138 uint64_t max = std::numeric_limits<int64_t>::max(); 139 if (new_position <= max) 140 return true; 141 } 142 } 143 Error::AddTo(error, 144 location, 145 errors::stream::kDomain, 146 errors::stream::kInvalidParameter, 147 "The stream offset value is out of range"); 148 return false; 149 } 150 151 bool CalculateStreamPosition(const base::Location& location, 152 int64_t offset, 153 Stream::Whence whence, 154 uint64_t current_position, 155 uint64_t stream_size, 156 uint64_t* new_position, 157 ErrorPtr* error) { 158 uint64_t pos = 0; 159 switch (whence) { 160 case Stream::Whence::FROM_BEGIN: 161 pos = 0; 162 break; 163 164 case Stream::Whence::FROM_CURRENT: 165 pos = current_position; 166 break; 167 168 case Stream::Whence::FROM_END: 169 pos = stream_size; 170 break; 171 172 default: 173 Error::AddTo(error, 174 location, 175 errors::stream::kDomain, 176 errors::stream::kInvalidParameter, 177 "Invalid stream position whence"); 178 return false; 179 } 180 181 if (!CheckInt64Overflow(location, pos, offset, error)) 182 return false; 183 184 *new_position = static_cast<uint64_t>(pos + offset); 185 return true; 186 } 187 188 void CopyData(StreamPtr in_stream, 189 StreamPtr out_stream, 190 const CopyDataSuccessCallback& success_callback, 191 const CopyDataErrorCallback& error_callback) { 192 CopyData(std::move(in_stream), std::move(out_stream), 193 std::numeric_limits<uint64_t>::max(), 4096, success_callback, 194 error_callback); 195 } 196 197 void CopyData(StreamPtr in_stream, 198 StreamPtr out_stream, 199 uint64_t max_size_to_copy, 200 size_t buffer_size, 201 const CopyDataSuccessCallback& success_callback, 202 const CopyDataErrorCallback& error_callback) { 203 auto state = std::make_shared<CopyDataState>(); 204 state->in_stream = std::move(in_stream); 205 state->out_stream = std::move(out_stream); 206 state->buffer.resize(buffer_size); 207 state->remaining_to_copy = max_size_to_copy; 208 state->size_copied = 0; 209 state->success_callback = success_callback; 210 state->error_callback = error_callback; 211 brillo::MessageLoop::current()->PostTask(FROM_HERE, 212 base::Bind(&PerformRead, state)); 213 } 214 215 } // namespace stream_utils 216 } // namespace brillo 217