Home | History | Annotate | Download | only in common_runtime
      1 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
      2 #define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
      3 
      4 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      5 
      6 Licensed under the Apache License, Version 2.0 (the "License");
      7 you may not use this file except in compliance with the License.
      8 You may obtain a copy of the License at
      9 
     10     http://www.apache.org/licenses/LICENSE-2.0
     11 
     12 Unless required by applicable law or agreed to in writing, software
     13 distributed under the License is distributed on an "AS IS" BASIS,
     14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     15 See the License for the specific language governing permissions and
     16 limitations under the License.
     17 ==============================================================================*/
     18 
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/port.h"
     24 
     25 namespace tensorflow {
     26 
     27 // PendingCounts is an internal helper class to keep track of pending and
     28 // dead counts for nodes, for use in the ExecutorState module.  It
     29 // holds a map from Handles to various counts for that handle.  This
     30 // information is needed per frame iteration. The amount of memory
     31 // needed for an iteration is the same across all executions of the
     32 // iteration. The memory amount and handles are precomputed at startup
     33 // using a Layout object.
     34 //
//    PendingCounts::Layout layout;
//    std::vector<PendingCounts::Handle> h(C);
//    for (int id = 0; id < C; id++) {
//      h[id] = layout.CreateHandle(max_pending[id], max_dead[id]);
//    }
     40 //
     41 // When we actually want to start an iteration we first create a
     42 // PendingCounts object and then index into it using the precomputed
     43 // handles:
     44 
     45 //    PendingCounts counts(layout);
     46 //    ...
     47 //    counts.decrement_pending(h[id], 1);
class PendingCounts {
 public:
  // The state machine for a node's execution.
  enum NodeState {
    // The pending count for the node > 0.
    PENDING_NOTREADY,
    // The pending count for the node == 0, but the node has not
    // started executing.
    PENDING_READY,
    // The node has started executing.
    STARTED,
    // The node has finished executing.
    COMPLETED
  };

  // An opaque handle indicating where in the PendingCounts data structure
  // the appropriate count information can be found.
  class Handle;
  // Given a node that needs to represent counts no larger than the
  // specified "max_pending_count" and "max_dead_count", create a
  // handle that can be passed to various PendingCounts routines
  // to retrieve the count data for this node.
  //
  // A Layout is a bump allocator over byte offsets: it only computes where
  // each node's record will live; the bytes themselves are allocated later
  // by the PendingCounts constructor.
  class Layout {
   public:
    Handle CreateHandle(size_t max_pending_count, size_t max_dead_count);

   private:
    friend class PendingCounts;
    int next_offset_ = 0;  // Next byte offset to allocate
  };

  // Create a new PendingCounts object that can hold the state of
  // all the Handles allocated from "layout".
  //
  // Note: the buffer is deliberately left uninitialized; callers are
  // expected to call set_initial_count() for every Handle before use.
  explicit PendingCounts(Layout layout)
      : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {}

  // Create a new PendingCounts object with the same layout and counts
  // as "other".  (Marked explicit so copies must be deliberate.)
  explicit PendingCounts(const PendingCounts& other)
      : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) {
    // new[] returns storage aligned for any fundamental type, so this
    // check guards the reinterpret_casts in Large()/Packed() below.
    CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    memcpy(bytes_, other.bytes_, other.num_bytes_);
  }

  ~PendingCounts() { delete[] bytes_; }

  // Reset the record for "h" to pending == pending_count, dead_count == 0,
  // not started.  Must be called for each Handle before any other accessor,
  // since the constructor does not zero the buffer.
  void set_initial_count(Handle h, size_t pending_count) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      c->pending = pending_count;
      c->dead_count = 0;
      c->has_started = 0;
    } else {
      PackedCounts* c = Packed(h);
      // Packed records hold pending in only 3 bits.
      DCHECK_LE(pending_count, kMaxCountForPackedCounts);
      c->pending = pending_count;
      c->dead_count = 0;
      c->has_started = 0;
    }
  }

  // Decode the current NodeState for "h" (see NodeStateForStruct).
  NodeState node_state(Handle h) {
    if (h.is_large_) {
      return NodeStateForStruct(Large(h));
    } else {
      return NodeStateForStruct(Packed(h));
    }
  }
  // Transition "h" from PENDING_READY to STARTED.
  // REQUIRES: pending(h) == 0 and the node has not already started.
  void mark_started(Handle h) {
    DCHECK_EQ(pending(h), 0);
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      DCHECK_EQ(c->has_started, 0);
      c->has_started = 1;
    } else {
      PackedCounts* c = Packed(h);
      DCHECK_EQ(c->has_started, 0);
      c->has_started = 1;
    }
  }
  // Transition "h" from STARTED to COMPLETED.
  // Once has_started is set, the pending field is reused to encode progress
  // (0 => STARTED, non-zero => COMPLETED -- see NodeStateForStruct), so
  // writing 1 here is what flips the state to COMPLETED.
  void mark_completed(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      DCHECK_EQ(c->has_started, 1);
      c->pending = 1;
    } else {
      PackedCounts* c = Packed(h);
      DCHECK_EQ(c->has_started, 1);
      c->pending = 1;
    }
  }
  // Number of inputs still pending for "h", or 0 once the node has started
  // (at which point the pending field no longer holds a count).
  int pending(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c->pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    } else {
      PackedCounts* c = Packed(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c->pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    }
  }
  // Subtract "v" from the pending count of "h" and return the new count.
  // REQUIRES: pending(h) >= v (checked in debug builds only; underflow in
  // release builds would corrupt the packed encoding).
  int decrement_pending(Handle h, int v) {
    DCHECK_GE(pending(h), v);
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      c->pending -= v;
      return c->pending;
    } else {
      PackedCounts* c = Packed(h);
      c->pending -= v;
      return c->pending;
    }
  }
  // Mark a merge node as live
  // REQUIRES: Node corresponding to "h" is a merge node
  //
  // Implementation: clears the low-order bit of pending.  This presumes the
  // executor initializes a merge node's pending count with the low bit set
  // meaning "no live input seen yet" -- confirm against the ExecutorState
  // initialization code.
  void mark_live(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c->pending &= ~static_cast<int>(0x1);
      }
    } else {
      PackedCounts* c = Packed(h);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        static_assert(7 == kMaxCountForPackedCounts,
                      "Live flag incorrect for max packed count");
        // 0x6 keeps the upper two of the three pending bits, i.e. it clears
        // only the low bit, mirroring the large-count branch above.
        c->pending &= 0x6;
      }
    }
  }

  // Number of dead inputs recorded so far for "h".
  int dead_count(Handle h) {
    int r = h.is_large_ ? Large(h)->dead_count : Packed(h)->dead_count;
    return r;
  }
  // Record one more dead input for "h".  No-op once the node has left the
  // PENDING_NOTREADY state, since dead counts only matter before execution.
  void increment_dead_count(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c->dead_count++;
      }
    } else {
      PackedCounts* c = Packed(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        // Packed records hold dead_count in only 3 bits.
        DCHECK_LT(c->dead_count, kMaxCountForPackedCounts);
        c->dead_count++;
      }
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once.  Equivalent to:
  //    if (increment_dead) increment_dead_count(h);
  //    decrement_pending(h, 1);
  //    *pending_result = pending(h);
  //    *dead_result = dead_count(h);
  void adjust_for_activation(Handle h, bool increment_dead, int* pending_result,
                             int* dead_result) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      adjust_for_activation_shared(Large(h), increment_dead, pending_result,
                                   dead_result);
    } else {
      adjust_for_activation_shared(Packed(h), increment_dead, pending_result,
                                   dead_result);
    }
  }

  // Opaque locator for one node's counts: a byte offset into bytes_ plus a
  // flag selecting which representation lives there.
  class Handle {
   public:
    Handle() : byte_offset_(0), is_large_(0) {}

   private:
    friend class PendingCounts;
    int byte_offset_ : 31;  // Byte offset of the rep in PendingCounts object
    bool is_large_ : 1;  // If true, rep is LargeCounts; otherwise PackedCounts
  };

 private:
  // Shared implementation of adjust_for_activation() for both
  // representations; T is PackedCounts or LargeCounts.
  template <typename T>
  inline void adjust_for_activation_shared(T* c, bool increment_dead,
                                           int* pending_result,
                                           int* dead_result) {
    if (increment_dead) {
      // Dead counts are only tracked while the node is still not ready.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c->dead_count++;
      }
    }
    c->pending -= 1;
    *dead_result = c->dead_count;
    *pending_result = c->pending;
  }

  // We keep track of the pending count and dead input count for each
  // graph node.  The representation used here is designed to be cache
  // efficient for graphs with large numbers of nodes, where most
  // nodes have relatively small maximum pending counts (e.g. for one
  // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less).  We
  // use one byte to hold both the pending and dead count for a node
  // where these together can fit in one byte, and we use a hash table
  // to handle the rare node ids that need larger counts than this.
  // Each frame in this subgraph has its own PendingCounts.

  // We use 3 bits each for dead_count and pending.
  static const int kMaxCountForPackedCounts = 7;

  // Most counts are small, so we pack a pending count and a dead
  // count into 3 bits each, use 1 bit to indicate that the node has
  // started computing.
  struct PackedCounts {
    uint8 pending : 3;
    uint8 dead_count : 3;
    uint8 has_started : 1;
  };

  // Full-width representation for the rare nodes whose counts exceed
  // kMaxCountForPackedCounts.
  struct LargeCounts {
    uint32 pending;
    uint32 dead_count : 31;
    uint8 has_started : 1;
  };

  // Decode the NodeState from a counts struct:
  //   !has_started: pending > 0  -> PENDING_NOTREADY, else PENDING_READY
  //    has_started: pending == 0 -> STARTED,          else COMPLETED
  template <typename T>
  NodeState NodeStateForStruct(T* c) const {
    if (c->has_started) {
      return (c->pending == 0) ? STARTED : COMPLETED;
    } else {
      return (c->pending == 0) ? PENDING_READY : PENDING_NOTREADY;
    }
  }
  // Return the LargeCounts record for "h".  REQUIRES: h.is_large_.
  inline LargeCounts* Large(Handle h) {
    DCHECK(h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(LargeCounts), num_bytes_);
    DCHECK_EQ(h.byte_offset_ % alignof(LargeCounts), 0);
    return reinterpret_cast<LargeCounts*>(bytes_ + h.byte_offset_);
  }
  // Return the PackedCounts record for "h".  REQUIRES: !h.is_large_.
  inline PackedCounts* Packed(Handle h) {
    DCHECK(!h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
    return reinterpret_cast<PackedCounts*>(bytes_ + h.byte_offset_);
  }

  const int num_bytes_;  // Just for bounds checking in debug mode
  char* bytes_;          // Array of num_bytes_ bytes

  void operator=(const PendingCounts&) = delete;
};
    307 
    308 inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
    309     size_t max_pending_count, size_t max_dead_count) {
    310   Handle result;
    311   if ((max_pending_count > kMaxCountForPackedCounts) ||
    312       (max_dead_count > kMaxCountForPackedCounts)) {
    313     int B = sizeof(LargeCounts);
    314     // Round byte offset to proper alignment
    315     DCHECK_GE(sizeof(LargeCounts), alignof(LargeCounts));
    316     int64 offset = ((static_cast<int64>(next_offset_) + B - 1) / B) * B;
    317     result.byte_offset_ = offset;
    318     result.is_large_ = true;
    319     next_offset_ = result.byte_offset_ + B;
    320   } else {
    321     result.byte_offset_ = next_offset_;
    322     result.is_large_ = false;
    323     DCHECK_EQ(sizeof(PackedCounts), 1);
    324     next_offset_ += sizeof(PackedCounts);
    325   }
    326   return result;
    327 }
    328 
    329 }  // end namespace tensorflow
    330 
    331 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
    332