/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <cmath>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {

// Out-of-class definitions for the static constexpr property keys declared in
// the header (needed pre-C++17 when the members are odr-used).
constexpr char HloCostAnalysis::kFlopsKey[];
constexpr char HloCostAnalysis::kTranscendentalsKey[];
constexpr char HloCostAnalysis::kBytesAccessedKey[];
constexpr char HloCostAnalysis::kOptimalSecondsKey[];

HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
    : HloCostAnalysis(shape_size, {}) {}

HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
                                 const Properties& per_second_rates)
    : shape_size_(shape_size), per_second_rates_(per_second_rates) {}

// Visitor hook that runs before each instruction's handler.
Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
  // Set current instruction cost values to reasonable default values. Each
  // handler can overwrite these values. In Postprocess, these values are
  // accumulated and written to the per-instruction maps.
  current_properties_.clear();
  current_should_compute_bottleneck_time_ = true;

  // The default number of bytes accessed for an instruction is the sum of the
  // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
  // handle opaque types.
  float bytes_accessed = shape_size_(hlo->shape());
  for (const HloInstruction* operand : hlo->operands()) {
    bytes_accessed += shape_size_(operand->shape());
  }
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  return Status::OK();
}

// Visitor hook that runs after each instruction's handler: derives the
// optimal-seconds estimate (unless a handler opted out) and folds the
// instruction's properties into the per-HLO map and the global totals.
Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
  if (current_should_compute_bottleneck_time_) {
    // Compute the time as the time of the bottleneck, i.e. the slowest property
    // given the per-second rate of each property.
    float optimal_seconds = 0.0f;
    for (const auto& property : current_properties_) {
      if (property.first != kOptimalSecondsKey) {
        // A property with no configured rate defaults to INFINITY per second,
        // so it contributes zero time to the bottleneck.
        optimal_seconds = std::max(
            optimal_seconds,
            property.second /
                GetProperty(property.first, per_second_rates_, INFINITY));
      }
    }
    current_properties_[kOptimalSecondsKey] = optimal_seconds;
  }

  // Each instruction must be recorded exactly once; a duplicate visit is a
  // visitor bug.
  TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
  for (const auto& property : current_properties_) {
    properties_sum_[property.first] += property.second;
  }

  return Status::OK();
}

// Shared handler for all element-wise operations.
Status HloCostAnalysis::HandleElementwiseOp(
    const HloInstruction* hlo_instruction) {
  const auto& shape = hlo_instruction->shape();
  // For element-wise operations, the number of computations is the same as the
  // number of elements in the output shape.
  auto computation_count = ShapeUtil::ElementsIn(shape);
  auto opcode = hlo_instruction->opcode();
  // We treat transcendental operations separately since one transcendental
  // operation can correspond to several floating point ops.
  if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower ||
      opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin ||
      opcode == HloOpcode::kCos) {
    current_properties_[kTranscendentalsKey] = computation_count;
  } else {
    // Note: transcendental operations are considered a separate category from
    // FLOPs.
    current_properties_[kFlopsKey] = computation_count;
  }
  return Status::OK();
}

// Looks up `key` in `properties`, returning `default_value` when absent.
/*static*/ float HloCostAnalysis::GetProperty(const string& key,
                                              const Properties& properties,
                                              const float default_value) {
  auto key_value = properties.find(key);
  return key_value == properties.end() ? default_value : key_value->second;
}

// Returns the value recorded under `key` for `hlo`, or 0 if the instruction
// has no recorded properties.
/*static*/ float HloCostAnalysis::GetPropertyForHlo(
    const HloInstruction& hlo, const string& key,
    const HloToProperties& hlo_to_properties) {
  auto it = hlo_to_properties.find(&hlo);
  if (it == hlo_to_properties.end()) {
    return 0.0f;
  } else {
    return GetProperty(key, it->second);
  }
}

// All element-wise opcode families share the elementwise cost model.
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
  return HandleElementwiseOp(compare);
}

Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
  return HandleElementwiseOp(clamp);
}

Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

// Parameters are computation inputs; no memory traffic is charged to them.
Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

// Constants are likewise not charged for memory traffic.
Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) {
  // GetTupleElement forwards a pointer and does not touch each element in the
  // output.
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

// The handlers below intentionally keep the Preprocess() defaults: bytes
// accessed equal to the sum of operand and output sizes, and zero flops.
Status HloCostAnalysis::HandleSelect(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSlice(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
  // The tuple instruction only gathers pointers from inputs (it doesn't iterate
  // through them). The memory touched is then only the size of the output
  // index table of the tuple.

  current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape());
  return Status::OK();
}

Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
  return Status::OK();
}

// Convert is element-wise: one conversion per output element.
Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
  return HandleElementwiseOp(convert);
}

Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
  return Status::OK();
}

// Dot cost: one fused multiply-add (FMA) per (output element, contracted
// element) pair.
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
  const Shape& lhs_shape = dot->operand(0)->shape();
  const Shape& rhs_shape = dot->operand(1)->shape();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  // Count of elements along the reduction dimension (last dimension for the
  // rhs).
  int64 reduction_width =
      lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0));
  // First divide by reduction width before multiplying by rhs elements to avoid
  // overflow.
  int64 fma_count;
  if (reduction_width == 0) {
    // Degenerate dot with an empty contracting dimension does no work.
    fma_count = 0;
  } else {
    fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) *
                ShapeUtil::ElementsIn(rhs_shape);
  }

  // We count an FMA operation as 2 floating point operations.
  current_properties_[kFlopsKey] = kFmaFlops * fma_count;
  return Status::OK();
}

Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) {
  return Status::OK();
}

// Map cost: the mapped computation's per-application cost multiplied by the
// number of output elements it is applied to.
Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
  // Compute properties of the mapped function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(map->to_apply()));

  // Compute the cost of all elements for this Map operation.
  const int64 element_count = ShapeUtil::ElementsIn(map->shape());
  for (const auto& property : sub_properties) {
    // Bytes accessed keeps the Preprocess() default (operands + output).
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return Status::OK();
}

// Reduce cost: one application of the reduction function per input element
// folded away, i.e. input element count minus output element count.
Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
  auto arg = reduce->operand(0);
  HloComputation* function = reduce->to_apply();
  // Compute the cost of the user function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this Reduce operation.
  int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
                          ShapeUtil::ElementsIn(reduce->shape());
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleReduceWindow(
    const HloInstruction* reduce_window) {
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  // Compute the properties of the reduction function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this ReduceWindow operation. For each
  // output element there are window_size - 1 reductions to perform.
  int64 window_element_count = 1;
  for (const auto& dimension : window.dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64 output_element_count =
      ShapeUtil::ElementsIn(reduce_window->shape());
  const int64 reduction_count =
      (window_element_count - 1) * output_element_count;
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleSelectAndScatter(
    const HloInstruction* instruction) {
  // Compute the properties of the select and scatter subcomputations.
  TF_ASSIGN_OR_RETURN(const Properties select_properties,
                      ProcessSubcomputation(instruction->select()));
  TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
                      ProcessSubcomputation(instruction->scatter()));

  // Compute the cost of all elements for this operation. For each scatter
  // source element there are window_size - 1 select computations to perform and
  // 1 scatter computation to perform.
  const auto source = instruction->operand(1);
  const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
  int64 window_element_count = 1;
  for (const auto& dimension : instruction->window().dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64 select_count = source_element_count * (window_element_count - 1);
  // Note the use of += here: select and scatter costs are combined per key.
  for (const auto& property : select_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] += property.second * select_count;
    }
  }
  for (const auto& property : scatter_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] +=
          property.second * source_element_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast does no computation and touches no memory.
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandlePad(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSend(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
  return Status::OK();
}

Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
  return Status::OK();
}

// Convolution cost: one FMA per output element per kernel element that
// contributes to it.
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  auto rhs_instruction = convolution->operand(1);
  const auto& dnums = convolution->convolution_dimension_numbers();
  const int64 output_features =
      convolution->shape().dimensions(dnums.output_feature_dimension());

  // For each output element, we do one fma per element in the kernel at some
  // given output feature index.
  const int64 fmas_per_output_element =
      output_features > 0
          ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features
          : 0;
  const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
  current_properties_[kFlopsKey] =
      output_elements * fmas_per_output_element * kFmaFlops;
  return Status::OK();
}

Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
  // The operand may be a tuple (presumably for FFT variants whose input
  // carries extra data — verify against the FFT op definition); the cost is
  // based on the first element's shape in that case.
  auto real_shape =
      ShapeUtil::IsTuple(fft->operand(0)->shape())
          ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
          : fft->operand(0)->shape();
  // One complex multiply is counted as 4 real FMAs.
  constexpr int kFmaPerComplexMul = 4;
  // NOTE(review): this multiplies the log2 factors of each transformed
  // dimension; a sum (log of the product of lengths) would match the usual
  // N*log(N) model for multi-dimensional FFTs — confirm intent.
  int64 log_factors = 1;
  for (int64 dim : fft->fft_length()) {
    log_factors *= tensorflow::Log2Floor(dim);
  }
  current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
                                   ShapeUtil::ElementsIn(real_shape);
  return Status::OK();
}

Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
  // We assume 2 replicas, so that each output element is the sum of two input
  // elements.
  //
  // TODO(b/33004697): Compute correct cost here, taking the actual number of
  // replicas into account.
  double flops = 0.0;
  ShapeUtil::ForEachSubshape(
      crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) {
        // Only array leaves contribute; tuple shells carry no data.
        if (ShapeUtil::IsArray(subshape)) {
          flops += ShapeUtil::ElementsIn(subshape);
        }
      });
  current_properties_[kFlopsKey] = flops;
  return Status::OK();
}

Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now, assume
  // the cost of each RNG is same as a transcendental operation.
  current_properties_[kTranscendentalsKey] =
      ShapeUtil::ElementsIn(random->shape());
  return Status::OK();
}

Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
  // Compute the properties of the fused expression and attribute them to the
  // fusion node. Use a dummy shape_size to avoid any errors from trying to
  // calculate the size of a shape that does not have a layout, since nodes
  // inside fusion nodes do not necessarily have a layout assigned.
  ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; };
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(fusion->fused_instructions_computation(),
                            &shape_size));

  // Fusion nodes that produce a tuple also produce the entries in the tuple.
  // Ignore the memory accessed inside fused ops, since fusion is supposed to
  // prevent intermediate data from touching slow memory.
  current_properties_[kBytesAccessedKey] = 0;
  ShapeUtil::ForEachSubshape(
      fusion->shape(),
      [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
        current_properties_[kBytesAccessedKey] += shape_size_(subshape);
      });

  for (const HloInstruction* operand : fusion->operands()) {
    current_properties_[kBytesAccessedKey] += shape_size_(operand->shape());
  }

  return Status::OK();
}

// A call is charged exactly the cost of its callee computation.
Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
  TF_ASSIGN_OR_RETURN(current_properties_,
                      ProcessSubcomputation(call->to_apply()));
  // The subcomputation's summed optimal_seconds is used as-is rather than
  // recomputed from the bottleneck rates in Postprocess.
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}

Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) {
  // We can't do anything sane with CustomCalls, since we don't know what they
  // do, and returning an error status will stop iteration over this
  // computation, which is probably also not what we want. So just punt and
  // return OK. This will cause all of the properties to be reported as 0,
  // which is fine.
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}

Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
  // This assumes a comparison based N*log(N) algorithm. As for all ops, the
  // actual properties of the op depend on the backend implementation.
  int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
  current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
  return Status::OK();
}

Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
  // Since the number of iterations of the while node will not always be
  // something that we can statically analyze, we cannot precisely compute the
  // cost of a while node. For now compute the cost of a single iteration.
  //
  // TODO(b/26346211): Improve the cost analysis for while nodes.
  TF_ASSIGN_OR_RETURN(const Properties body_properties,
                      ProcessSubcomputation(xla_while->while_body()));

  TF_ASSIGN_OR_RETURN(const Properties condition_properties,
                      ProcessSubcomputation(xla_while->while_condition()));

  // One iteration costs body + condition, summed key by key.
  current_properties_.clear();
  for (const auto& property : body_properties) {
    current_properties_[property.first] += property.second;
  }
  for (const auto& property : condition_properties) {
    current_properties_[property.first] += property.second;
  }
  current_should_compute_bottleneck_time_ = false;

  return Status::OK();
}

Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
  // Compute the cost of the true and false computations and take the maximum
  // from those for each property.
520 TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, 521 ProcessSubcomputation(conditional->true_computation())); 522 TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, 523 ProcessSubcomputation(conditional->false_computation())); 524 current_properties_ = true_computation_properties; 525 for (const auto& property : false_computation_properties) { 526 if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { 527 current_properties_[property.first] = 528 std::max(current_properties_[property.first], property.second); 529 } 530 } 531 current_should_compute_bottleneck_time_ = false; 532 533 return Status::OK(); 534 } 535 536 Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { 537 // Gather does not issue any flops. 538 return Status::OK(); 539 } 540 541 Status HloCostAnalysis::FinishVisit(const HloInstruction*) { 542 return Status::OK(); 543 } 544 545 float HloCostAnalysis::flop_count() const { 546 return GetProperty(kFlopsKey, properties_sum_); 547 } 548 549 float HloCostAnalysis::transcendental_count() const { 550 return GetProperty(kTranscendentalsKey, properties_sum_); 551 } 552 553 float HloCostAnalysis::bytes_accessed() const { 554 return GetProperty(kBytesAccessedKey, properties_sum_); 555 } 556 557 float HloCostAnalysis::optimal_seconds() const { 558 return GetProperty(kOptimalSecondsKey, properties_sum_); 559 } 560 561 int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { 562 return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_); 563 } 564 565 int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const { 566 return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_); 567 } 568 569 int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { 570 return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); 571 } 572 573 float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { 574 return GetPropertyForHlo(hlo, 
kOptimalSecondsKey, hlo_properties_); 575 } 576 577 StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation( 578 HloComputation* computation, const ShapeSizeFunction* shape_size) { 579 if (shape_size == nullptr) { 580 shape_size = &shape_size_; 581 } 582 HloCostAnalysis visitor(*shape_size, per_second_rates_); 583 TF_RETURN_IF_ERROR(computation->Accept(&visitor)); 584 return visitor.properties(); 585 } 586 587 } // namespace xla 588