1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ 16 #define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ 17 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/core/framework/device_attributes.pb.h" 22 #include "tensorflow/core/framework/device_base.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/lib/core/refcount.h" 25 #include "tensorflow/core/lib/core/status.h" 26 27 namespace tensorflow { 28 class BufRendezvous; 29 class CancellationManager; 30 class CompleteGroupRequest; 31 class CompleteGroupResponse; 32 class CompleteInstanceRequest; 33 class CompleteInstanceResponse; 34 class Device; 35 class DeviceMgr; 36 class GetStepSequenceRequest; 37 class GetStepSequenceResponse; 38 class Op; 39 class Tensor; 40 41 // Types of supported collective operations. 42 enum CollectiveType { 43 REDUCTION_COLLECTIVE = 0, 44 BROADCAST_COLLECTIVE, 45 GATHER_COLLECTIVE, 46 UNDEFINED_COLLECTIVE, 47 }; 48 49 // Data common to all members of a device group. 50 // All members share the same device set but its order is 51 // particular to an instance so it is stored there. 52 struct CollGroupParams { 53 int32 group_key; 54 int32 group_size; 55 DeviceType device_type; 56 int32 num_tasks; // number of distinct tasks in group 57 string ToString() const; 58 CollGroupParams() 59 : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {} 60 }; 61 62 // The best implementation of a collective op depends on many factors 63 // including the number of devices involved, the topology of 64 // interconnects between them and the sizes of inputs. This structure 65 // is used in generating and representing data movement choreography 66 // for each specific algorithm, hence it does not have a single, fixed 67 // interpretation. On first execution the runtime will update this 68 // structure with decisions that will guide all subsequent executions. 69 struct CollImplDetails { 70 string collective_name; 71 std::vector<std::vector<int>> subdiv_permutations; 72 std::vector<int> subdiv_offsets; 73 std::vector<int> subdiv_source_rank; // rank of source in each subdiv 74 std::vector<int32> 75 dependencies; // collective instances on which this node depends 76 }; 77 78 // Data common to all members of a collective instance. 79 struct CollInstanceParams { 80 // Identifies all participating graph nodes. 81 int32 instance_key = -1; 82 CollectiveType type = UNDEFINED_COLLECTIVE; 83 DataType data_type = DT_FLOAT; 84 TensorShape shape = {0}; 85 // Fully qualified name of device for each member, in default rank order. 86 std::vector<string> device_names; 87 // Task name prefix of corresponding device name. 88 std::vector<string> task_names; 89 // True if every task has the same number of devices. 90 bool same_num_devices_per_task = false; 91 // Task -> number of devices on that task. 92 std::unordered_map<string, int32> num_devices_per_task; 93 // If passed in to GPUOptions in ConfigProto, defines a good ring order for 94 // GPUs. Assumes same GPU configuration at each worker. 95 string gpu_ring_order = ""; 96 // Valid when using a communicator-based collective mechanism, e.g. NCCL. 97 string communicator_key; 98 CollImplDetails impl_details; 99 string ToString() const; 100 CollInstanceParams& operator=(const struct CollInstanceParams& other); 101 }; 102 103 // Data common to all instance members in the same task. 104 struct CollTaskParams { 105 // True for devices that are local to the process, i.e. no RPC needed. 106 std::vector<bool> is_local; 107 string ToString() const; 108 }; 109 110 // Unique to a single CollectiveOp node. 111 struct CollectiveParams { 112 CollGroupParams group; 113 CollInstanceParams instance; 114 CollTaskParams task; 115 116 string name = ""; // node name used only for log or error messages 117 int default_rank = -1; // index of this op within device_names 118 bool is_source = false; // broadcast only 119 int source_rank = -1; // broadcast only 120 // Rank of this device in each subdivision permutation. 121 std::vector<int> subdiv_rank; 122 std::unique_ptr<OpKernel> merge_op; // reduction only 123 std::unique_ptr<OpKernel> final_op; // reduction only 124 string ToString() const; 125 }; 126 127 class CollectiveExecutor; 128 129 // Interface that provides resolution of device localities. 130 class DeviceResolverInterface { 131 public: 132 virtual ~DeviceResolverInterface() {} 133 134 // Collects DeviceLocality protobufs from all of the devices identified 135 // in 'col_params'. 136 virtual void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params, 137 std::vector<DeviceLocality>* localities, 138 const StatusCallback& done) = 0; 139 140 // Populate *locality with the DeviceLocality of the specified 141 // device. 142 virtual void GetLocalityAsync(const string& device, const string& task, 143 DeviceLocality* locality, 144 const StatusCallback& done) = 0; 145 146 // Clear the cache of device data belonging 147 // to the specified task. 148 virtual void ClearTask(const string& task) = 0; 149 }; 150 151 // Interface that provides resolution of shared CollectiveParams fields. 152 class ParamResolverInterface { 153 public: 154 virtual ~ParamResolverInterface() {} 155 156 // Called by each collective op at first execution in order to fill out 157 // the CollectiveParams structure with data gathered from the full 158 // (maybe distributed) collection of peer nodes. 159 virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp, 160 CancellationManager* cancel_mgr, 161 const StatusCallback& done) = 0; 162 163 // Used within a distributed implementation to discover/verify 164 // data shared across a device group. 165 virtual void CompleteGroupAsync(const CompleteGroupRequest* request, 166 CompleteGroupResponse* response, 167 CancellationManager* cancel_mgr, 168 const StatusCallback& done) = 0; 169 170 // Used within a distributed implementation to discover/verify data 171 // shared across an instance group. 172 virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request, 173 CompleteInstanceResponse* response, 174 CancellationManager* cancel_mgr, 175 const StatusCallback& done) = 0; 176 }; 177 178 // Graphs which utilize Collective Ops in a common instance must 179 // execute with identical step_ids even if they are disjoint graphs 180 // run by otherwise independent tasks. This interface supplies 181 // coordinated step_ids to use in such cases. 182 class StepSequenceInterface { 183 public: 184 virtual ~StepSequenceInterface() {} 185 186 // Used with a distributed implementation to coordinate step_id 187 // sequences across tasks. 188 virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request, 189 GetStepSequenceResponse* response, 190 const StatusCallback& done) = 0; 191 192 // Refresh the local per-graph_key step_id sequence from collective 193 // group leader, if applicable. 194 virtual void RefreshStepIdSequenceAsync(int64 graph_key, 195 const StatusCallback& done) = 0; 196 197 // Returns the step_id that should be used for initiating a new execution 198 // on the specified graph. May return the same step_id multiple times if 199 // RetireStepId or RefreshStepIdReservation is not called. 200 virtual int64 NextStepId(int64 graph_key) = 0; 201 202 // Reports that execution of the given step has completed successfully. 203 // Should be called immediately after a step completes with OK status, 204 // prior to calling NextStepId(). If the step fails, don't call. 205 virtual void RetireStepId(int64 graph_key, int64 step_id) = 0; 206 }; 207 208 // Interface that provides access to per-step CollectiveExecutor 209 // instances and various distributed resolution capabilities. 210 class CollectiveExecutorMgrInterface : public StepSequenceInterface { 211 public: 212 virtual ~CollectiveExecutorMgrInterface() {} 213 214 // Returns the step-specific CollectiveExecutor, creating if one does not 215 // already exist. The caller assumes ownership of one Ref on the object. 216 virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0; 217 218 // If there is a CollectiveExecutor for step_id, remove it from the 219 // table. 220 virtual void Cleanup(int64 step_id) = 0; 221 222 virtual ParamResolverInterface* GetParamResolver() const = 0; 223 224 virtual DeviceResolverInterface* GetDeviceResolver() const = 0; 225 }; 226 227 // Interface that a Collective Op implementation uses to exchange data 228 // with peers. Note that data exchange is currently limited to types 229 // for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric 230 // types. 231 class PeerAccessInterface { 232 public: 233 virtual ~PeerAccessInterface() {} 234 235 virtual void RecvFromPeer(const string& peer_device, const string& peer_task, 236 bool peer_is_local, const string& key, 237 Device* to_device, DeviceContext* to_device_ctx, 238 const AllocatorAttributes& to_alloc_attr, 239 Tensor* to_tensor, 240 const DeviceLocality& client_locality, 241 int dev_to_dev_stream_index, 242 const StatusCallback& done) = 0; 243 244 virtual void PostToPeer(const string& peer_device, const string& peer_task, 245 const string& key, Device* from_device, 246 DeviceContext* from_device_ctx, 247 const AllocatorAttributes& from_alloc_attr, 248 const Tensor* from_tensor, 249 const DeviceLocality& client_locality, 250 const StatusCallback& done) = 0; 251 }; 252 253 class PerStepCollectiveRemoteAccess; 254 255 // A step-specific object that can execute a collective operation completely 256 // described by a CollectiveParams object. 257 class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted { 258 public: 259 virtual void StartAbort(const Status& s) {} 260 261 virtual void ExecuteAsync(OpKernelContext* ctx, 262 const CollectiveParams& col_params, 263 const string& exec_key, StatusCallback done) { 264 done(errors::Internal( 265 "A collective Op has been called in a context in which " 266 "a CollectiveExecutor has not been provided.")); 267 } 268 269 virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp, 270 CancellationManager* cancel_mgr, 271 StatusCallback done) { 272 done(errors::Internal( 273 "A collective Op has been called in a context in which " 274 "a CollectiveExecutor has not been provided.")); 275 } 276 277 virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; } 278 279 // `WaitForDependencies` and `Launched` are used for fine-grained control of 280 // execution order between collective instances. These functions are intended 281 // to be called in `Run` function of collective implementations, and may be 282 // used to make part, or whole, of the collective execution ordered with 283 // respect to other collective instances. 284 // 285 // `WaitForDependencies` will block until it is safe to continue the callee's 286 // execution, where safety is defined as: ordered with respect to the 287 // collective instances defined in the callee's `wait_for` attribute. 288 virtual void WaitForDependencies(const CollectiveParams& col_params) {} 289 // `Launched` unblocks the dependent collective instances by recording that 290 // this callee device has completed the critical portion of the collective 291 // execution. 292 virtual void Launched(const CollectiveParams& col_params) {} 293 294 // Used to designate an invalid group or instance key. 295 static int64 kInvalidId; 296 297 // Lexically scoped handle for Ref. 298 class Handle { 299 public: 300 explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) { 301 if (!inherit_ref) ce->Ref(); 302 } 303 ~Handle() { ce_->Unref(); } 304 CollectiveExecutor* get() const { return ce_; } 305 306 private: 307 CollectiveExecutor* ce_; 308 }; 309 310 protected: 311 explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem) 312 : cem_(cem) {} 313 314 // For use only by derived classes 315 static OpKernelContext::Params* CtxParams(OpKernelContext* ctx); 316 CollectiveExecutorMgrInterface* cem_; 317 318 TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor); 319 }; 320 321 // Interface of a helper object that provides a CollectiveExecutor with 322 // all of the remote access it needs. 323 class CollectiveRemoteAccess : public PeerAccessInterface, 324 public DeviceResolverInterface { 325 public: 326 virtual ~CollectiveRemoteAccess() {} 327 328 virtual BufRendezvous* buf_rendezvous() = 0; 329 }; 330 331 // A per-step version of CollectiveRemoteAccess that cleans up outstanding 332 // communications in case step execution is abandoned. 333 class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess { 334 public: 335 virtual ~PerStepCollectiveRemoteAccess() {} 336 virtual void StartAbort(const Status& s) = 0; 337 }; 338 339 class CollectiveContext { 340 public: 341 CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr, 342 OpKernelContext* ctx, OpKernelContext::Params* op_params, 343 const CollectiveParams& col_params, const string& exec_key, 344 int64 step_id, const Tensor* input, Tensor* output); 345 346 virtual ~CollectiveContext() = default; 347 348 CollectiveExecutor* col_exec; // Not owned 349 const DeviceMgr* dev_mgr; // Not owned 350 OpKernelContext* op_ctx; // Not owned 351 OpKernelContext::Params* op_params; // Not owned 352 const CollectiveParams& col_params; 353 const string exec_key; 354 const int64 step_id; 355 const Tensor* input; // Not owned 356 Tensor* output; // Not owned 357 Device* device; // The device for which this instance labors 358 const string device_name; 359 DeviceLocality device_locality; 360 }; 361 362 // Interface of a Collective Op implementation. Each specific CollectiveOp will 363 // implement this interface and register the implementation via the 364 // CollectiveRegistry detailed below. See common_runtime/ring_reducer and 365 // common_runtime/hierarchical_tree_broadcaster for examples. 366 class CollectiveImplementationInterface { 367 public: 368 virtual ~CollectiveImplementationInterface() = default; 369 370 // Initializes the portions of `col_params` specific to this 371 // implementation. Called exactly once for every Collective instance during 372 // the CollectiveParams resolution process when the graph is first executed, 373 // at the end of `CompleteInstanceLocal()`. 374 // NOTE(ayushd): This is effectively a static function because it modifies the 375 // `col_params` passed in and should not manipulate any data members. However 376 // because it is virtual and needs to be implemented by every derived class we 377 // do not mark it as static. 378 virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0; 379 380 // Prepares the CollectiveContext for executing this CollectiveImplementation. 381 // Called from CollectiveExecutor right before calling Run(). The 382 // CollectiveContext passed in must outlive the CollectiveImplementation 383 // object. 384 virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0; 385 386 // Initializes instance params at the beginning of `CompleteInstanceLocal()`, 387 // unlike `InitializeCollectiveParams` which is called at the end. This 388 // function is called before all devices in the instance are discovered, and 389 // may be used to broadcast data via the shared `InstanceRec` object in 390 // collective param resolution to all devices. 391 virtual Status InitializeInstanceBeforeGroupDiscovery( 392 CollectiveParams* col_params) = 0; 393 394 // Processes and moves data according to the logic of this Collective 395 // implementation. Relies on appropriate initialization of op-specific 396 // CollectiveParams in InitializeCollectiveParams(), as well as appropriate 397 // context initialization in InitializeCollectiveContext(). 398 virtual void Run(StatusCallback done) = 0; 399 }; 400 401 // Static-methods only class for registering and looking up collective 402 // implementations. 403 class CollectiveRegistry { 404 public: 405 using Factory = std::function<CollectiveImplementationInterface*()>; 406 // Looks up a previously registered CollectiveImplementation under 407 // `collective_name`. If found, creates an instance of the implementation and 408 // assign to `implementation`. 409 static Status Lookup(const string& collective_name, 410 CollectiveImplementationInterface** implementation); 411 412 // Looks up a previously registered CollectiveImplementation under 413 // `collective_name`. If found, returns the static instance of this 414 // implementation via `implementation`. This instance should only be used to 415 // call InitializateCollectiveParams. 416 static Status LookupParamResolverInstance( 417 const string& collective_name, 418 CollectiveImplementationInterface** implementation); 419 420 // Returns all registered collective implementations. 421 static void GetAll( 422 std::vector<CollectiveImplementationInterface*>* implementations); 423 424 private: 425 friend class CollectiveRegistration; 426 // Registers a CollectiveImplementation with name `collective_name` and 427 // factory `factory`. The latter is a function used to create instances of 428 // the CollectiveImplementation. Also creates a static instance of the 429 // implementation - this instance is used during param resolution and should 430 // only be used to call InitializeCollectiveParams. 431 static Status Register(const string& collective_name, Factory factory); 432 433 static Status LookupHelper(const string& collective_name, 434 CollectiveImplementationInterface** implementation, 435 bool param_resolver); 436 }; 437 438 // Class used to call CollectiveRegistry::Register. This should only be used to 439 // create a global static object. 440 class CollectiveRegistration { 441 public: 442 CollectiveRegistration(const string& collective_name, 443 CollectiveRegistry::Factory factory) { 444 TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory)); 445 } 446 }; 447 448 #define REGISTER_COLLECTIVE(name, implementation) \ 449 static CollectiveRegistration register_##name##_collective( \ 450 #name, []() { return new implementation; }); 451 452 } // namespace tensorflow 453 454 #endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ 455