/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {
/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                          int64 num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which must be a common multiple of
  // the various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
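  // Illustrative example (not from the original source; assumes
  // EIGEN_MAX_ALIGN_BYTES == 64 for the sake of concrete numbers): splitting
  // 1000 floats into 3 chunks gives base_chunk_elts = ceil(1000 / 3) = 334,
  // i.e. 1336 bytes per chunk.  1336 % 64 == 56, so the rounding below adds
  // 8 bytes (2 elements), yielding aligned chunks of 336 elements
  // (1344 bytes, a multiple of 64).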
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
  int64 chunk_bytes = base_chunk_elts * elt_bytes;
  int64 diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator,
                        bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
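  // Because chunk_elts_ may be rounded up for alignment, the final chunk can
  // be shorter than chunk_elts_, or even empty; the bounds are clamped to
  // data_end_ below.  (Continuing the illustrative example above: 1000 floats
  // in 3 aligned chunks yield chunks of 336, 336 and 328 elements.)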
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64 start = chunk_elts_ * i;
    int64 num_elts = ChunkElts(i);
    // If this chunk is empty the prior chunk might also be short
    // so always take an empty slice from the front of the tensor
    // to avoid an illegal offset check failure somewhere.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64 base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override {
    Tensor t(dt_, TensorShape({}));
    t.scalar<T>()() = v;
    return t;
  }

  Tensor Scalar(Allocator* a) const override {
    Tensor t(a, dt_, TensorShape({}));
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << output->dtype()
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

void BaseCollectiveExecutor::StartAbort(const Status& s) {
  LOG(WARNING) << "BaseCollectiveExecutor::StartAbort " << s;
  remote_access_->StartAbort(s);
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams& col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // On any individual collective Op failure we need to abort the
  // BufRendezvous so that other Ops in the instance don't hang
  // waiting for transmissions that will never happen.  Do so after a
  // delay so that the original error status is more likely to
  // propagate up, and peers are unlikely to re-create the purged
  // BufRendezvous by late-arriving requests.
  StatusCallback done_safe = [this, done](const Status& s) {
    if (!s.ok()) {
      Ref();  // Ensure this lasts until the closure executes.
      SchedNonBlockingClosureAfter(1000000, [this, s] {
        remote_access_->buf_rendezvous()->StartAbort(s);
        Unref();
      });
    }
    done(s);
  };

  Tensor* output = ctx->mutable_output(0);
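  // Only reductions and gathers, and the source rank of a broadcast, read an
  // input tensor; every other participant only produces output.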
  const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
                         col_params.instance.type == GATHER_COLLECTIVE ||
                         (col_params.instance.type == BROADCAST_COLLECTIVE &&
                          col_params.is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  CollectiveContext* col_ctx =
      new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
                            exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    delete col_ctx;
    delete col_impl;
    return;
  }
  // Run in an I/O thread, so as not to starve the executor threads.
  // TODO(b/80529858): Instead of forking every per-device Collective
  // Op off into its own thread, consider queuing them on a
  // fixed-size thread-pool dedicated to running CollectiveOps.
  SchedClosure([col_impl, col_ctx, done_safe]() {
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      done_safe(s);
      delete col_ctx;
      delete col_impl;
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    StatusCallback done) {
  cp->instance.gpu_ring_order = *gpu_ring_order_;
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
}

Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  *col_impl = nullptr;
  Status status;
  switch (col_params.instance.data_type) {
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU) {
        status = errors::Internal(
            "CollectiveImplementation does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      }
      TF_FALLTHROUGH_INTENDED;
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      status = CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
      break;
    }
    default:
      status = errors::Internal(
          "CollectiveImplementation does not support datatype ",
          col_params.instance.data_type);
  }
  return status;
}

bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32 instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

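// Records that one local device has launched this collective instance.  The
// per-instance count starts at the number of devices this task contributes
// and is decremented on each launch; once it reaches zero, collectives that
// list this instance in their dependencies (see CheckDependencies) are
// unblocked.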
void BaseCollectiveExecutor::Launched(const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.instance.task_names[col_params.default_rank];
    const int32 num_devices =
        col_params.instance.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow