1 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ 2 #define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ 3 4 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 5 6 Licensed under the Apache License, Version 2.0 (the "License"); 7 you may not use this file except in compliance with the License. 8 You may obtain a copy of the License at 9 10 http://www.apache.org/licenses/LICENSE-2.0 11 12 Unless required by applicable law or agreed to in writing, software 13 distributed under the License is distributed on an "AS IS" BASIS, 14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 See the License for the specific language governing permissions and 16 limitations under the License. 17 ==============================================================================*/ 18 19 #include "tensorflow/core/lib/gtl/flatmap.h" 20 #include "tensorflow/core/lib/hash/hash.h" 21 #include "tensorflow/core/platform/logging.h" 22 #include "tensorflow/core/platform/macros.h" 23 #include "tensorflow/core/util/port.h" 24 25 namespace tensorflow { 26 27 // PendingCounts is an internal helper class to keep track of pending and 28 // dead counts for nodes, for use in the ExecutorState module. It 29 // holds a map from Handles to various counts for that handle. This 30 // information is needed per frame iteration. The amount of memory 31 // needed for an iteration is the same across all executions of the 32 // iteration. The memory amount and handles are precomputed at startup 33 // using a Layout object. 34 // 35 // PendingCounts::Layout layout; 36 // std::vector<PendingCounts::Handle> h(C); 37 // for (int id = 0; id < C; id++) { 38 // h[id] = r.AddHandle(max_pending[id], max_dead[id]); 39 // } 40 // 41 // When we actually want to start an iteration we first create a 42 // PendingCounts object and then index into it using the precomputed 43 // handles: 44 45 // PendingCounts counts(layout); 46 // ... 47 // counts.decrement_pending(h[id], 1); 48 class PendingCounts { 49 public: 50 // The state machine for a node's execution. 51 enum NodeState { 52 // The pending count for the node > 0. 53 PENDING_NOTREADY, 54 // The pending count for the node == 0, but the node has not 55 // started executing. 56 PENDING_READY, 57 // The node has started executing. 58 STARTED, 59 // The node has finished executing. 60 COMPLETED 61 }; 62 63 // An opaque handle indicating where in the PendingCounts data structure 64 // the appropriate count information can be found. 65 class Handle; 66 // Given a node that needs to represent counts no larger than the 67 // specified "max_pending_count" and "max_dead_count", create a 68 // handle that can be passed to various PendingCounts routines 69 // to retrieve the count data for this node. 70 class Layout { 71 public: 72 Handle CreateHandle(size_t max_pending_count, size_t max_dead_count); 73 74 private: 75 friend class PendingCounts; 76 int next_offset_ = 0; // Next byte offset to allocate 77 }; 78 79 // Create a new PendingCounts object that can hold the state of 80 // all the Handles allocated from "final_allocator". 81 explicit PendingCounts(Layout layout) 82 : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {} 83 84 // Create a new PendingCounts object with the same layout and counts 85 // as "other". 86 explicit PendingCounts(const PendingCounts& other) 87 : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) { 88 CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0); 89 memcpy(bytes_, other.bytes_, other.num_bytes_); 90 } 91 92 ~PendingCounts() { delete[] bytes_; } 93 94 void set_initial_count(Handle h, size_t pending_count) { 95 if (h.is_large_) { 96 LargeCounts* c = Large(h); 97 c->pending = pending_count; 98 c->dead_count = 0; 99 c->has_started = 0; 100 } else { 101 PackedCounts* c = Packed(h); 102 DCHECK_LE(pending_count, kMaxCountForPackedCounts); 103 c->pending = pending_count; 104 c->dead_count = 0; 105 c->has_started = 0; 106 } 107 } 108 109 NodeState node_state(Handle h) { 110 if (h.is_large_) { 111 return NodeStateForStruct(Large(h)); 112 } else { 113 return NodeStateForStruct(Packed(h)); 114 } 115 } 116 void mark_started(Handle h) { 117 DCHECK_EQ(pending(h), 0); 118 if (h.is_large_) { 119 LargeCounts* c = Large(h); 120 DCHECK_EQ(c->has_started, 0); 121 c->has_started = 1; 122 } else { 123 PackedCounts* c = Packed(h); 124 DCHECK_EQ(c->has_started, 0); 125 c->has_started = 1; 126 } 127 } 128 void mark_completed(Handle h) { 129 if (h.is_large_) { 130 LargeCounts* c = Large(h); 131 DCHECK_EQ(c->has_started, 1); 132 c->pending = 1; 133 } else { 134 PackedCounts* c = Packed(h); 135 DCHECK_EQ(c->has_started, 1); 136 c->pending = 1; 137 } 138 } 139 int pending(Handle h) { 140 if (h.is_large_) { 141 LargeCounts* c = Large(h); 142 if (PENDING_NOTREADY == NodeStateForStruct(c)) { 143 return c->pending; 144 } else { 145 // The pending count encodes the state once the node has 146 // started, so just return 0. 147 return 0; 148 } 149 } else { 150 PackedCounts* c = Packed(h); 151 if (PENDING_NOTREADY == NodeStateForStruct(c)) { 152 return c->pending; 153 } else { 154 // The pending count encodes the state once the node has 155 // started, so just return 0. 156 return 0; 157 } 158 } 159 } 160 int decrement_pending(Handle h, int v) { 161 DCHECK_GE(pending(h), v); 162 if (h.is_large_) { 163 LargeCounts* c = Large(h); 164 c->pending -= v; 165 return c->pending; 166 } else { 167 PackedCounts* c = Packed(h); 168 c->pending -= v; 169 return c->pending; 170 } 171 } 172 // Mark a merge node as live 173 // REQUIRES: Node corresponding to "h" is a merge node 174 void mark_live(Handle h) { 175 if (h.is_large_) { 176 LargeCounts* c = Large(h); 177 // Only do anything if the node hasn't already started executing. 178 if (PENDING_NOTREADY == NodeStateForStruct(c)) { 179 c->pending &= ~static_cast<int>(0x1); 180 } 181 } else { 182 PackedCounts* c = Packed(h); 183 // Only do anything if the node hasn't already started executing. 184 if (PENDING_NOTREADY == NodeStateForStruct(c)) { 185 static_assert(7 == kMaxCountForPackedCounts, 186 "Live flag incorrect for max packed count"); 187 c->pending &= 0x6; 188 } 189 } 190 } 191 192 int dead_count(Handle h) { 193 int r = h.is_large_ ? Large(h)->dead_count : Packed(h)->dead_count; 194 return r; 195 } 196 void increment_dead_count(Handle h) { 197 if (h.is_large_) { 198 LargeCounts* c = Large(h); 199 if (PENDING_NOTREADY == NodeStateForStruct(c)) { 200 c->dead_count++; 201 } 202 } else { 203 PackedCounts* c = Packed(h); 204 if (PENDING_NOTREADY == NodeStateForStruct(c)) { 205 DCHECK_LT(c->dead_count, kMaxCountForPackedCounts); 206 c->dead_count++; 207 } 208 } 209 } 210 211 // A streamlined routine that does several pieces of bookkeeping at 212 // once. Equivalent to: 213 // if (increment_dead) increment_dead_count(h); 214 // decrement_pending(h, 1); 215 // *pending_result = pending(h); 216 // *dead_result = dead_count(h); 217 void adjust_for_activation(Handle h, bool increment_dead, int* pending_result, 218 int* dead_result) { 219 DCHECK_GE(pending(h), 1); 220 if (h.is_large_) { 221 adjust_for_activation_shared(Large(h), increment_dead, pending_result, 222 dead_result); 223 } else { 224 adjust_for_activation_shared(Packed(h), increment_dead, pending_result, 225 dead_result); 226 } 227 } 228 229 class Handle { 230 public: 231 Handle() : byte_offset_(0), is_large_(0) {} 232 233 private: 234 friend class PendingCounts; 235 int byte_offset_ : 31; // Byte offset of the rep in PendingCounts object 236 bool is_large_ : 1; // If true, rep is LargeCounts; otherwise PackedCounts 237 }; 238 239 private: 240 template <typename T> 241 inline void adjust_for_activation_shared(T* c, bool increment_dead, 242 int* pending_result, 243 int* dead_result) { 244 if (increment_dead) { 245 if (PENDING_NOTREADY == NodeStateForStruct(c)) { 246 c->dead_count++; 247 } 248 } 249 c->pending -= 1; 250 *dead_result = c->dead_count; 251 *pending_result = c->pending; 252 } 253 254 // We keep track of the pending count and dead input count for each 255 // graph node. The representation used here is designed to be cache 256 // efficient for graphs with large numbers of nodes, where most 257 // nodes have relatively small maximum pending counts (e.g. for one 258 // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less). We 259 // use one byte to hold both the pending and dead count for a node 260 // where these together can fit in one byte, and we use a hash table 261 // to handle the rare node ids that need larger counts than this. 262 // Each frame in this subgraph has its own PendingCounts. 263 264 // We use 3 bits each for dead_count and pending. 265 static const int kMaxCountForPackedCounts = 7; 266 267 // Most counts are small, so we pack a pending count and a dead 268 // count into 3 bits each, use 1 bit to indicate that the node has 269 // started computing. 270 struct PackedCounts { 271 uint8 pending : 3; 272 uint8 dead_count : 3; 273 uint8 has_started : 1; 274 }; 275 276 struct LargeCounts { 277 uint32 pending; 278 uint32 dead_count : 31; 279 uint8 has_started : 1; 280 }; 281 282 template <typename T> 283 NodeState NodeStateForStruct(T* c) const { 284 if (c->has_started) { 285 return (c->pending == 0) ? STARTED : COMPLETED; 286 } else { 287 return (c->pending == 0) ? PENDING_READY : PENDING_NOTREADY; 288 } 289 } 290 inline LargeCounts* Large(Handle h) { 291 DCHECK(h.is_large_); 292 DCHECK_LE(h.byte_offset_ + sizeof(LargeCounts), num_bytes_); 293 DCHECK_EQ(h.byte_offset_ % alignof(LargeCounts), 0); 294 return reinterpret_cast<LargeCounts*>(bytes_ + h.byte_offset_); 295 } 296 inline PackedCounts* Packed(Handle h) { 297 DCHECK(!h.is_large_); 298 DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_); 299 return reinterpret_cast<PackedCounts*>(bytes_ + h.byte_offset_); 300 } 301 302 const int num_bytes_; // Just for bounds checking in debug mode 303 char* bytes_; // Array of num_bytes_ bytes 304 305 void operator=(const PendingCounts&) = delete; 306 }; 307 308 inline PendingCounts::Handle PendingCounts::Layout::CreateHandle( 309 size_t max_pending_count, size_t max_dead_count) { 310 Handle result; 311 if ((max_pending_count > kMaxCountForPackedCounts) || 312 (max_dead_count > kMaxCountForPackedCounts)) { 313 int B = sizeof(LargeCounts); 314 // Round byte offset to proper alignment 315 DCHECK_GE(sizeof(LargeCounts), alignof(LargeCounts)); 316 int64 offset = ((static_cast<int64>(next_offset_) + B - 1) / B) * B; 317 result.byte_offset_ = offset; 318 result.is_large_ = true; 319 next_offset_ = result.byte_offset_ + B; 320 } else { 321 result.byte_offset_ = next_offset_; 322 result.is_large_ = false; 323 DCHECK_EQ(sizeof(PackedCounts), 1); 324 next_offset_ += sizeof(PackedCounts); 325 } 326 return result; 327 } 328 329 } // end namespace tensorflow 330 331 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ 332