Home | History | Annotate | Download | only in framework
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
     16 #define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
     17 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
     26 
     27 namespace tensorflow {
     28 class BufRendezvous;
     29 class CancellationManager;
     30 class CompleteGroupRequest;
     31 class CompleteGroupResponse;
     32 class CompleteInstanceRequest;
     33 class CompleteInstanceResponse;
     34 class Device;
     35 class DeviceMgr;
     36 class GetStepSequenceRequest;
     37 class GetStepSequenceResponse;
     38 class Op;
     39 class Tensor;
     40 
     41 // Types of supported collective operations.
     42 enum CollectiveType {
     43   REDUCTION_COLLECTIVE = 0,
     44   BROADCAST_COLLECTIVE,
     45   GATHER_COLLECTIVE,
     46   UNDEFINED_COLLECTIVE,
     47 };
     48 
     49 // Data common to all members of a device group.
     50 // All members share the same device set but its order is
     51 // particular to an instance so it is stored there.
     52 struct CollGroupParams {
     53   int32 group_key;
     54   int32 group_size;
     55   DeviceType device_type;
     56   int32 num_tasks;  // number of distinct tasks in group
     57   string ToString() const;
     58   CollGroupParams()
     59       : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
     60 };
     61 
     62 // The best implementation of a collective op depends on many factors
     63 // including the number of devices involved, the topology of
     64 // interconnects between them and the sizes of inputs.  This structure
     65 // is used in generating and representing data movement choreography
     66 // for each specific algorithm, hence it does not have a single, fixed
     67 // interpretation.  On first execution the runtime will update this
     68 // structure with decisions that will guide all subsequent executions.
     69 struct CollImplDetails {
     70   string collective_name;
     71   std::vector<std::vector<int>> subdiv_permutations;
     72   std::vector<int> subdiv_offsets;
     73   std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
     74   std::vector<int32>
     75       dependencies;  // collective instances on which this node depends
     76 };
     77 
     78 // Data common to all members of a collective instance.
     79 struct CollInstanceParams {
     80   // Identifies all participating graph nodes.
     81   int32 instance_key = -1;
     82   CollectiveType type = UNDEFINED_COLLECTIVE;
     83   DataType data_type = DT_FLOAT;
     84   TensorShape shape = {0};
     85   // Fully qualified name of device for each member, in default rank order.
     86   std::vector<string> device_names;
     87   // Task name prefix of corresponding device name.
     88   std::vector<string> task_names;
     89   // True if every task has the same number of devices.
     90   bool same_num_devices_per_task = false;
     91   // Task -> number of devices on that task.
     92   std::unordered_map<string, int32> num_devices_per_task;
     93   // If passed in to GPUOptions in ConfigProto, defines a good ring order for
     94   // GPUs.  Assumes same GPU configuration at each worker.
     95   string gpu_ring_order = "";
     96   // Valid when using a communicator-based collective mechanism, e.g. NCCL.
     97   string communicator_key;
     98   CollImplDetails impl_details;
     99   string ToString() const;
    100   CollInstanceParams& operator=(const struct CollInstanceParams& other);
    101 };
    102 
    103 // Data common to all instance members in the same task.
    104 struct CollTaskParams {
    105   // True for devices that are local to the process, i.e. no RPC needed.
    106   std::vector<bool> is_local;
    107   string ToString() const;
    108 };
    109 
    110 // Unique to a single CollectiveOp node.
    111 struct CollectiveParams {
    112   CollGroupParams group;
    113   CollInstanceParams instance;
    114   CollTaskParams task;
    115 
    116   string name = "";        // node name used only for log or error messages
    117   int default_rank = -1;   // index of this op within device_names
    118   bool is_source = false;  // broadcast only
    119   int source_rank = -1;    // broadcast only
    120   // Rank of this device in each subdivision permutation.
    121   std::vector<int> subdiv_rank;
    122   std::unique_ptr<OpKernel> merge_op;  // reduction only
    123   std::unique_ptr<OpKernel> final_op;  // reduction only
    124   string ToString() const;
    125 };
    126 
    127 class CollectiveExecutor;
    128 
    129 // Interface that provides resolution of device localities.
    130 class DeviceResolverInterface {
    131  public:
    132   virtual ~DeviceResolverInterface() {}
    133 
    134   // Collects DeviceLocality protobufs from all of the devices identified
    135   // in 'col_params'.
    136   virtual void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
    137                                         std::vector<DeviceLocality>* localities,
    138                                         const StatusCallback& done) = 0;
    139 
    140   // Populate *locality with the DeviceLocality of the specified
    141   // device.
    142   virtual void GetLocalityAsync(const string& device, const string& task,
    143                                 DeviceLocality* locality,
    144                                 const StatusCallback& done) = 0;
    145 
    146   // Clear the cache of device data belonging
    147   // to the specified task.
    148   virtual void ClearTask(const string& task) = 0;
    149 };
    150 
    151 // Interface that provides resolution of shared CollectiveParams fields.
    152 class ParamResolverInterface {
    153  public:
    154   virtual ~ParamResolverInterface() {}
    155 
    156   // Called by each collective op at first execution in order to fill out
    157   // the CollectiveParams structure with data gathered from the full
    158   // (maybe distributed) collection of peer nodes.
    159   virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
    160                                    CancellationManager* cancel_mgr,
    161                                    const StatusCallback& done) = 0;
    162 
    163   // Used within a distributed implementation to discover/verify
    164   // data shared across a device group.
    165   virtual void CompleteGroupAsync(const CompleteGroupRequest* request,
    166                                   CompleteGroupResponse* response,
    167                                   CancellationManager* cancel_mgr,
    168                                   const StatusCallback& done) = 0;
    169 
    170   // Used within a distributed implementation to discover/verify data
    171   // shared across an instance group.
    172   virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
    173                                      CompleteInstanceResponse* response,
    174                                      CancellationManager* cancel_mgr,
    175                                      const StatusCallback& done) = 0;
    176 };
    177 
    178 // Graphs which utilize Collective Ops in a common instance must
    179 // execute with identical step_ids even if they are disjoint graphs
    180 // run by otherwise independent tasks.  This interface supplies
    181 // coordinated step_ids to use in such cases.
    182 class StepSequenceInterface {
    183  public:
    184   virtual ~StepSequenceInterface() {}
    185 
    186   // Used with a distributed implementation to coordinate step_id
    187   // sequences across tasks.
    188   virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
    189                                     GetStepSequenceResponse* response,
    190                                     const StatusCallback& done) = 0;
    191 
    192   // Refresh the local per-graph_key step_id sequence from collective
    193   // group leader, if applicable.
    194   virtual void RefreshStepIdSequenceAsync(int64 graph_key,
    195                                           const StatusCallback& done) = 0;
    196 
    197   // Returns the step_id that should be used for initiating a new execution
    198   // on the specified graph. May return the same step_id multiple times if
    199   // RetireStepId or RefreshStepIdReservation is not called.
    200   virtual int64 NextStepId(int64 graph_key) = 0;
    201 
    202   // Reports that execution of the given step has completed successfully.
    203   // Should be called immediately after a step completes with OK status,
    204   // prior to calling NextStepId().  If the step fails, don't call.
    205   virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
    206 };
    207 
    208 // Interface that provides access to per-step CollectiveExecutor
    209 // instances and various distributed resolution capabilities.
    210 class CollectiveExecutorMgrInterface : public StepSequenceInterface {
    211  public:
    212   virtual ~CollectiveExecutorMgrInterface() {}
    213 
    214   // Returns the step-specific CollectiveExecutor, creating if one does not
    215   // already exist.  The caller assumes ownership of one Ref on the object.
    216   virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0;
    217 
    218   // If there is a CollectiveExecutor for step_id, remove it from the
    219   // table.
    220   virtual void Cleanup(int64 step_id) = 0;
    221 
    222   virtual ParamResolverInterface* GetParamResolver() const = 0;
    223 
    224   virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
    225 };
    226 
    227 // Interface that a Collective Op implementation uses to exchange data
    228 // with peers.  Note that data exchange is currently limited to types
    229 // for which DMAHelper::CanUseDMA() returns true, i.e.  dense numeric
    230 // types.
    231 class PeerAccessInterface {
    232  public:
    233   virtual ~PeerAccessInterface() {}
    234 
    235   virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
    236                             bool peer_is_local, const string& key,
    237                             Device* to_device, DeviceContext* to_device_ctx,
    238                             const AllocatorAttributes& to_alloc_attr,
    239                             Tensor* to_tensor,
    240                             const DeviceLocality& client_locality,
    241                             int dev_to_dev_stream_index,
    242                             const StatusCallback& done) = 0;
    243 
    244   virtual void PostToPeer(const string& peer_device, const string& peer_task,
    245                           const string& key, Device* from_device,
    246                           DeviceContext* from_device_ctx,
    247                           const AllocatorAttributes& from_alloc_attr,
    248                           const Tensor* from_tensor,
    249                           const DeviceLocality& client_locality,
    250                           const StatusCallback& done) = 0;
    251 };
    252 
    253 class PerStepCollectiveRemoteAccess;
    254 
    255 // A step-specific object that can execute a collective operation completely
    256 // described by a CollectiveParams object.
    257 class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
    258  public:
    259   virtual void StartAbort(const Status& s) {}
    260 
    261   virtual void ExecuteAsync(OpKernelContext* ctx,
    262                             const CollectiveParams& col_params,
    263                             const string& exec_key, StatusCallback done) {
    264     done(errors::Internal(
    265         "A collective Op has been called in a context in which "
    266         "a CollectiveExecutor has not been provided."));
    267   }
    268 
    269   virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
    270                                    CancellationManager* cancel_mgr,
    271                                    StatusCallback done) {
    272     done(errors::Internal(
    273         "A collective Op has been called in a context in which "
    274         "a CollectiveExecutor has not been provided."));
    275   }
    276 
    277   virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; }
    278 
    279   // `WaitForDependencies` and `Launched` are used for fine-grained control of
    280   // execution order between collective instances.  These functions are intended
    281   // to be called in `Run` function of collective implementations, and may be
    282   // used to make part, or whole, of the collective execution ordered with
    283   // respect to other collective instances.
    284   //
    285   // `WaitForDependencies` will block until it is safe to continue the callee's
    286   // execution, where safety is defined as: ordered with respect to the
    287   // collective instances defined in the callee's `wait_for` attribute.
    288   virtual void WaitForDependencies(const CollectiveParams& col_params) {}
    289   // `Launched` unblocks the dependent collective instances by recording that
    290   // this callee device has completed the critical portion of the collective
    291   // execution.
    292   virtual void Launched(const CollectiveParams& col_params) {}
    293 
    294   // Used to designate an invalid group or instance key.
    295   static int64 kInvalidId;
    296 
    297   // Lexically scoped handle for Ref.
    298   class Handle {
    299    public:
    300     explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
    301       if (!inherit_ref) ce->Ref();
    302     }
    303     ~Handle() { ce_->Unref(); }
    304     CollectiveExecutor* get() const { return ce_; }
    305 
    306    private:
    307     CollectiveExecutor* ce_;
    308   };
    309 
    310  protected:
    311   explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
    312       : cem_(cem) {}
    313 
    314   // For use only by derived classes
    315   static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
    316   CollectiveExecutorMgrInterface* cem_;
    317 
    318   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
    319 };
    320 
    321 // Interface of a helper object that provides a CollectiveExecutor with
    322 // all of the remote access it needs.
    323 class CollectiveRemoteAccess : public PeerAccessInterface,
    324                                public DeviceResolverInterface {
    325  public:
    326   virtual ~CollectiveRemoteAccess() {}
    327 
    328   virtual BufRendezvous* buf_rendezvous() = 0;
    329 };
    330 
    331 // A per-step version of CollectiveRemoteAccess that cleans up outstanding
    332 // communications in case step execution is abandoned.
    333 class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
    334  public:
    335   virtual ~PerStepCollectiveRemoteAccess() {}
    336   virtual void StartAbort(const Status& s) = 0;
    337 };
    338 
    339 class CollectiveContext {
    340  public:
    341   CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
    342                     OpKernelContext* ctx, OpKernelContext::Params* op_params,
    343                     const CollectiveParams& col_params, const string& exec_key,
    344                     int64 step_id, const Tensor* input, Tensor* output);
    345 
    346   virtual ~CollectiveContext() = default;
    347 
    348   CollectiveExecutor* col_exec;        // Not owned
    349   const DeviceMgr* dev_mgr;            // Not owned
    350   OpKernelContext* op_ctx;             // Not owned
    351   OpKernelContext::Params* op_params;  // Not owned
    352   const CollectiveParams& col_params;
    353   const string exec_key;
    354   const int64 step_id;
    355   const Tensor* input;  // Not owned
    356   Tensor* output;       // Not owned
    357   Device* device;       // The device for which this instance labors
    358   const string device_name;
    359   DeviceLocality device_locality;
    360 };
    361 
    362 // Interface of a Collective Op implementation.  Each specific CollectiveOp will
    363 // implement this interface and register the implementation via the
    364 // CollectiveRegistry detailed below.  See common_runtime/ring_reducer and
    365 // common_runtime/hierarchical_tree_broadcaster for examples.
    366 class CollectiveImplementationInterface {
    367  public:
    368   virtual ~CollectiveImplementationInterface() = default;
    369 
    370   // Initializes the portions of `col_params` specific to this
    371   // implementation.  Called exactly once for every Collective instance during
    372   // the CollectiveParams resolution process when the graph is first executed,
    373   // at the end of `CompleteInstanceLocal()`.
    374   // NOTE(ayushd): This is effectively a static function because it modifies the
    375   // `col_params` passed in and should not manipulate any data members.  However
    376   // because it is virtual and needs to be implemented by every derived class we
    377   // do not mark it as static.
    378   virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
    379 
    380   // Prepares the CollectiveContext for executing this CollectiveImplementation.
    381   // Called from CollectiveExecutor right before calling Run().  The
    382   // CollectiveContext passed in must outlive the CollectiveImplementation
    383   // object.
    384   virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
    385 
    386   // Initializes instance params at the beginning of `CompleteInstanceLocal()`,
    387   // unlike `InitializeCollectiveParams` which is called at the end.  This
    388   // function is called before all devices in the instance are discovered, and
    389   // may be used to broadcast data via the shared `InstanceRec` object in
    390   // collective param resolution to all devices.
    391   virtual Status InitializeInstanceBeforeGroupDiscovery(
    392       CollectiveParams* col_params) = 0;
    393 
    394   // Processes and moves data according to the logic of this Collective
    395   // implementation.  Relies on appropriate initialization of op-specific
    396   // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
    397   // context initialization in InitializeCollectiveContext().
    398   virtual void Run(StatusCallback done) = 0;
    399 };
    400 
    401 // Static-methods only class for registering and looking up collective
    402 // implementations.
    403 class CollectiveRegistry {
    404  public:
    405   using Factory = std::function<CollectiveImplementationInterface*()>;
    406   // Looks up a previously registered CollectiveImplementation under
    407   // `collective_name`.  If found, creates an instance of the implementation and
    408   // assign to `implementation`.
    409   static Status Lookup(const string& collective_name,
    410                        CollectiveImplementationInterface** implementation);
    411 
    412   // Looks up a previously registered CollectiveImplementation under
    413   // `collective_name`.  If found, returns the static instance of this
    414   // implementation via `implementation`.  This instance should only be used to
    415   // call InitializateCollectiveParams.
    416   static Status LookupParamResolverInstance(
    417       const string& collective_name,
    418       CollectiveImplementationInterface** implementation);
    419 
    420   // Returns all registered collective implementations.
    421   static void GetAll(
    422       std::vector<CollectiveImplementationInterface*>* implementations);
    423 
    424  private:
    425   friend class CollectiveRegistration;
    426   // Registers a CollectiveImplementation with name `collective_name` and
    427   // factory `factory`.  The latter is a function used to create instances of
    428   // the CollectiveImplementation.  Also creates a static instance of the
    429   // implementation - this instance is used during param resolution and should
    430   // only be used to call InitializeCollectiveParams.
    431   static Status Register(const string& collective_name, Factory factory);
    432 
    433   static Status LookupHelper(const string& collective_name,
    434                              CollectiveImplementationInterface** implementation,
    435                              bool param_resolver);
    436 };
    437 
    438 // Class used to call CollectiveRegistry::Register.  This should only be used to
    439 // create a global static object.
    440 class CollectiveRegistration {
    441  public:
    442   CollectiveRegistration(const string& collective_name,
    443                          CollectiveRegistry::Factory factory) {
    444     TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
    445   }
    446 };
    447 
    448 #define REGISTER_COLLECTIVE(name, implementation)             \
    449   static CollectiveRegistration register_##name##_collective( \
    450       #name, []() { return new implementation; });
    451 
    452 }  // namespace tensorflow
    453 
    454 #endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
    455