1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/queue_base.h" 17 18 #include <vector> 19 #include "tensorflow/core/framework/node_def.pb.h" 20 #include "tensorflow/core/framework/tensor_shape.h" 21 #include "tensorflow/core/kernels/batch_util.h" 22 #include "tensorflow/core/lib/core/errors.h" 23 #include "tensorflow/core/platform/mutex.h" 24 #include "tensorflow/core/platform/types.h" 25 26 namespace tensorflow { 27 28 namespace { 29 30 template <DataType DT> 31 Status HandleSliceToElement(const Tensor& parent, Tensor* element, 32 int64 index) { 33 typedef typename EnumToDataType<DT>::Type T; 34 DCHECK_NE(parent.dim_size(0), 0); 35 DCHECK_GE(index, 0); 36 if (element->NumElements() != (parent.NumElements() / parent.dim_size(0))) { 37 TensorShape chip_shape = parent.shape(); 38 chip_shape.RemoveDim(0); 39 return errors::Internal( 40 "HandleSliceToElement Cannot copy slice: number of elements does not " 41 "match. Shapes are: [element]: ", 42 element->shape().DebugString(), 43 ", [parent slice]: ", chip_shape.DebugString()); 44 } 45 auto parent_as_matrix = parent.flat_outer_dims<T>(); 46 element->flat<T>() = parent_as_matrix.chip(index, 0); 47 return Status::OK(); 48 } 49 50 } // namespace 51 52 QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes, 53 const std::vector<TensorShape>& component_shapes, 54 const string& name) 55 : capacity_(capacity), 56 component_dtypes_(component_dtypes), 57 component_shapes_(component_shapes), 58 name_(name), 59 closed_(false) {} 60 61 QueueBase::~QueueBase() {} 62 63 Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const { 64 if (tuple.size() != static_cast<size_t>(num_components())) { 65 return errors::InvalidArgument( 66 "Wrong number of components in tuple. Expected ", num_components(), 67 ", got ", tuple.size()); 68 } 69 for (size_t i = 0; i < tuple.size(); ++i) { 70 if (tuple[i].dtype() != component_dtypes_[i]) { 71 return errors::InvalidArgument( 72 "Type mismatch in tuple component ", i, ". Expected ", 73 DataTypeString(component_dtypes_[i]), ", got ", 74 DataTypeString(tuple[i].dtype())); 75 } 76 } 77 return Status::OK(); 78 } 79 80 // static 81 string QueueBase::ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) { 82 string result = "["; 83 bool first = true; 84 for (const TensorShape& shape : shapes) { 85 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString()); 86 first = false; 87 } 88 strings::StrAppend(&result, "]"); 89 return result; 90 } 91 92 Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def, 93 const string& op) const { 94 if (node_def.op() != op) { 95 return errors::InvalidArgument("Shared queue '", name_, "' has type '", op, 96 "' that does not match type of Node '", 97 node_def.name(), "': ", node_def.op()); 98 } 99 return Status::OK(); 100 } 101 102 Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def, 103 int32 capacity) const { 104 int32 requested_capacity = -1; 105 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity)); 106 if (requested_capacity < 0) requested_capacity = kUnbounded; 107 if (requested_capacity != capacity) { 108 return errors::InvalidArgument("Shared queue '", name_, "' has capacity ", 109 capacity, " but requested capacity was ", 110 requested_capacity); 111 } 112 return Status::OK(); 113 } 114 115 Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const { 116 DataTypeVector requested_dtypes; 117 TF_RETURN_IF_ERROR( 118 GetNodeAttr(node_def, "component_types", &requested_dtypes)); 119 if (requested_dtypes != component_dtypes_) { 120 return errors::InvalidArgument("Shared queue '", name_, 121 "' has component types ", 122 DataTypeSliceString(component_dtypes_), 123 " but requested component types were ", 124 DataTypeSliceString(requested_dtypes)); 125 } 126 return Status::OK(); 127 } 128 129 Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { 130 std::vector<TensorShape> requested_shapes; 131 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes)); 132 if (requested_shapes != component_shapes_) { 133 return errors::InvalidArgument("Shared queue '", name_, 134 "' has component shapes ", 135 ShapeListString(component_shapes_), 136 " but requested component shapes were ", 137 ShapeListString(requested_shapes)); 138 } 139 return Status::OK(); 140 } 141 142 // TODO(mrry): If these checks become a bottleneck, find a way to 143 // reduce the number of times that they are called. 144 Status QueueBase::ValidateTuple(const Tuple& tuple) { 145 TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); 146 if (specified_shapes()) { 147 for (size_t i = 0; i < tuple.size(); ++i) { 148 if (!component_shapes_[i].IsSameSize(tuple[i].shape())) { 149 return errors::InvalidArgument( 150 "Shape mismatch in tuple component ", i, ". Expected ", 151 component_shapes_[i].DebugString(), ", got ", 152 tuple[i].shape().DebugString()); 153 } 154 } 155 } 156 return Status::OK(); 157 } 158 159 // TODO(mrry): If these checks become a bottleneck, find a way to 160 // reduce the number of times that they are called. 161 Status QueueBase::ValidateManyTuple(const Tuple& tuple) { 162 TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); 163 const int64 batch_size = tuple[0].dim_size(0); 164 if (specified_shapes()) { 165 for (size_t i = 0; i < tuple.size(); ++i) { 166 // Expected shape is [batch_size] + component_shapes_[i] 167 const TensorShape expected_shape = ManyOutShape(i, batch_size); 168 if (!expected_shape.IsSameSize(tuple[i].shape())) { 169 return errors::InvalidArgument("Shape mismatch in tuple component ", i, 170 ". Expected ", 171 expected_shape.DebugString(), ", got ", 172 tuple[i].shape().DebugString()); 173 } 174 } 175 } else { 176 for (size_t i = 1; i < tuple.size(); ++i) { 177 if (tuple[i].dim_size(0) != batch_size) { 178 return errors::InvalidArgument( 179 "All input tensors must have the same size in the 0th ", 180 "dimension. Component ", i, " has ", tuple[i].dim_size(0), 181 ", and should have ", batch_size); 182 } 183 } 184 } 185 return Status::OK(); 186 } 187 188 void QueueBase::Cancel(Action action, CancellationManager* cancellation_manager, 189 CancellationToken token) { 190 DoneCallback callback = nullptr; 191 { 192 mutex_lock lock(mu_); 193 std::deque<Attempt>* attempts = 194 action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_; 195 196 for (Attempt& attempt : *attempts) { 197 if (attempt.cancellation_manager == cancellation_manager && 198 attempt.cancellation_token == token) { 199 if (!attempt.is_cancelled) { 200 attempt.is_cancelled = true; 201 if (action == kEnqueue) { 202 attempt.context->SetStatus( 203 errors::Cancelled("Enqueue operation was cancelled")); 204 } else { 205 attempt.context->SetStatus( 206 errors::Cancelled("Dequeue operation was cancelled")); 207 } 208 std::swap(callback, attempt.done_callback); 209 } 210 break; 211 } 212 } 213 } 214 if (callback) { 215 callback(); 216 FlushUnlocked(); 217 } 218 } 219 220 void QueueBase::CloseAndCancel() { 221 std::vector<DoneCallback> callbacks; 222 { 223 mutex_lock lock(mu_); 224 closed_ = true; 225 for (Attempt& attempt : enqueue_attempts_) { 226 if (!attempt.is_cancelled) { 227 attempt.is_cancelled = true; 228 attempt.context->SetStatus( 229 errors::Cancelled("Enqueue operation was cancelled")); 230 callbacks.emplace_back(std::move(attempt.done_callback)); 231 } 232 } 233 } 234 for (const DoneCallback& callback : callbacks) { 235 callback(); 236 } 237 FlushUnlocked(); 238 } 239 240 void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues, 241 DoneCallback callback) { 242 if (cancel_pending_enqueues) { 243 CloseAndCancel(); 244 callback(); 245 } else { 246 { 247 mutex_lock lock(mu_); 248 enqueue_attempts_.emplace_back( 249 0, callback, ctx, nullptr, CancellationManager::kInvalidToken, 250 [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 251 if (closed_) { 252 attempt->context->SetStatus( 253 errors::Cancelled("Queue '", name_, "' is already closed.")); 254 } else { 255 closed_ = true; 256 } 257 return kComplete; 258 }); 259 } 260 FlushUnlocked(); 261 } 262 } 263 264 bool QueueBase::TryAttemptLocked(Action action, 265 std::vector<CleanUp>* clean_up) { 266 std::deque<Attempt>* attempts = 267 action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_; 268 269 bool progress = false; 270 bool done = false; 271 while (!done && !attempts->empty()) { 272 if (attempts->front().is_cancelled) { 273 if (action == kEnqueue) { 274 if (closed_) { 275 VLOG(1) << "Skipping cancelled enqueue attempt"; 276 } else { 277 LOG(WARNING) 278 << name_ 279 << ": Skipping cancelled enqueue attempt with queue not closed"; 280 } 281 } else { 282 if (closed_) { 283 VLOG(1) << "Skipping cancelled dequeue attempt"; 284 } else { 285 LOG(WARNING) 286 << name_ 287 << ": Skipping cancelled dequeue attempt with queue not closed"; 288 } 289 } 290 attempts->pop_front(); 291 } else { 292 Attempt* cur_attempt = &attempts->front(); 293 switch (cur_attempt->run_callback(cur_attempt)) { 294 case kNoProgress: 295 done = true; 296 break; 297 case kProgress: 298 done = true; 299 progress = true; 300 break; 301 case kComplete: 302 progress = true; 303 clean_up->emplace_back(std::move(cur_attempt->done_callback), 304 cur_attempt->cancellation_token, 305 cur_attempt->context->cancellation_manager()); 306 attempts->pop_front(); 307 break; 308 } 309 } 310 } 311 return progress; 312 } 313 314 void QueueBase::FlushUnlocked() { 315 std::vector<CleanUp> clean_up; 316 Ref(); 317 { 318 mutex_lock lock(mu_); 319 bool changed; 320 do { 321 changed = TryAttemptLocked(kEnqueue, &clean_up); 322 changed = TryAttemptLocked(kDequeue, &clean_up) || changed; 323 } while (changed); 324 } 325 Unref(); 326 for (const auto& to_clean : clean_up) { 327 if (to_clean.to_deregister != CancellationManager::kInvalidToken) { 328 // NOTE(mrry): We can safely ignore the return value of 329 // DeregisterCallback because the mutex mu_ ensures that the 330 // cleanup action only executes once. 331 to_clean.cm->DeregisterCallback(to_clean.to_deregister); 332 } 333 to_clean.finished(); 334 } 335 } 336 337 Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, 338 int64 index) { 339 return batch_util::CopySliceToElement(parent, element, index); 340 } 341 342 /* static */ 343 Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, 344 int64 index) { 345 return batch_util::CopyElementToSlice(element, parent, index); 346 } 347 348 } // namespace tensorflow 349