1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" 17 18 #include <cmath> 19 20 #include "tensorflow/compiler/xla/shape_util.h" 21 #include "tensorflow/compiler/xla/status_macros.h" 22 #include "tensorflow/compiler/xla/util.h" 23 #include "tensorflow/compiler/xla/window_util.h" 24 #include "tensorflow/core/lib/core/bits.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/lib/gtl/map_util.h" 27 28 namespace xla { 29 30 constexpr char HloCostAnalysis::kFlopsKey[]; 31 constexpr char HloCostAnalysis::kTranscendentalsKey[]; 32 constexpr char HloCostAnalysis::kBytesAccessedKey[]; 33 constexpr char HloCostAnalysis::kOptimalSecondsKey[]; 34 35 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size) 36 : HloCostAnalysis(shape_size, {}) {} 37 38 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size, 39 const Properties& per_second_rates) 40 : shape_size_(shape_size), per_second_rates_(per_second_rates) {} 41 42 Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) { 43 // Set current instruction cost values to reasonable default values. Each 44 // handler can overwrite these values. In Postprocess, these values are 45 // accumulated and written to the per-instruction maps. 46 current_properties_.clear(); 47 current_should_compute_bottleneck_time_ = true; 48 49 // The default number of bytes accessed for an instruction is the sum of the 50 // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not 51 // handle opaque types. 52 float bytes_accessed = GetShapeSize(hlo->shape()); 53 for (const HloInstruction* operand : hlo->operands()) { 54 bytes_accessed += GetShapeSize(operand->shape()); 55 } 56 current_properties_[kBytesAccessedKey] = bytes_accessed; 57 58 return Status::OK(); 59 } 60 61 Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) { 62 if (current_should_compute_bottleneck_time_) { 63 // Compute the time as the time of the bottleneck, i.e. the slowest property 64 // given the per-second rate of each property. 65 float optimal_seconds = 0.0f; 66 for (const auto& property : current_properties_) { 67 if (property.first != kOptimalSecondsKey) { 68 optimal_seconds = std::max( 69 optimal_seconds, 70 property.second / 71 GetProperty(property.first, per_second_rates_, INFINITY)); 72 } 73 } 74 current_properties_[kOptimalSecondsKey] = optimal_seconds; 75 } 76 77 TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second); 78 for (const auto& property : current_properties_) { 79 properties_sum_[property.first] += property.second; 80 } 81 82 return Status::OK(); 83 } 84 85 Status HloCostAnalysis::HandleElementwiseOp( 86 const HloInstruction* hlo_instruction) { 87 const auto& shape = hlo_instruction->shape(); 88 // For element-wise operations, the number of computations is the same as the 89 // number of elements in the output shape. 90 auto computation_count = ShapeUtil::ElementsIn(shape); 91 auto opcode = hlo_instruction->opcode(); 92 // We treat transcendental operations separately since one transcendental 93 // operation can correspond to several floating point ops. 94 if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog || 95 opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt || 96 opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh || 97 opcode == HloOpcode::kSin || opcode == HloOpcode::kCos) { 98 current_properties_[kTranscendentalsKey] = computation_count; 99 } else { 100 // Note: transcendental operations are considered a separate category from 101 // FLOPs. 102 current_properties_[kFlopsKey] = computation_count; 103 } 104 return Status::OK(); 105 } 106 107 /*static*/ float HloCostAnalysis::GetProperty(const string& key, 108 const Properties& properties, 109 const float default_value) { 110 auto key_value = properties.find(key); 111 return key_value == properties.end() ? default_value : key_value->second; 112 } 113 114 /*static*/ float HloCostAnalysis::GetPropertyForHlo( 115 const HloInstruction& hlo, const string& key, 116 const HloToProperties& hlo_to_properties) { 117 auto it = hlo_to_properties.find(&hlo); 118 if (it == hlo_to_properties.end()) { 119 return 0.0f; 120 } else { 121 return GetProperty(key, it->second); 122 } 123 } 124 125 int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const { 126 if (!LayoutUtil::HasLayout(shape)) { 127 return 0; 128 } 129 return shape_size_(shape); 130 } 131 132 Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { 133 return HandleElementwiseOp(hlo); 134 } 135 136 Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) { 137 return HandleElementwiseOp(hlo); 138 } 139 140 Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) { 141 return HandleElementwiseOp(compare); 142 } 143 144 Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) { 145 return HandleElementwiseOp(clamp); 146 } 147 148 Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { 149 return HandleElementwiseOp(hlo); 150 } 151 152 Status HloCostAnalysis::HandleParameter(const HloInstruction*) { 153 current_should_compute_bottleneck_time_ = false; 154 current_properties_[kBytesAccessedKey] = 0; 155 current_properties_[kOptimalSecondsKey] = 0; 156 return Status::OK(); 157 } 158 159 Status HloCostAnalysis::HandleConstant(const HloInstruction*) { 160 current_should_compute_bottleneck_time_ = false; 161 current_properties_[kBytesAccessedKey] = 0; 162 current_properties_[kOptimalSecondsKey] = 0; 163 return Status::OK(); 164 } 165 166 Status HloCostAnalysis::HandleIota(const HloInstruction*) { 167 return Status::OK(); 168 } 169 170 Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { 171 // GetTupleElement forwards a pointer and does not touch each element in the 172 // output. 173 current_should_compute_bottleneck_time_ = false; 174 current_properties_[kBytesAccessedKey] = 0; 175 current_properties_[kOptimalSecondsKey] = 0; 176 return Status::OK(); 177 } 178 179 Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) { 180 return HandleElementwiseOp(hlo); 181 } 182 183 Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) { 184 return Status::OK(); 185 } 186 187 Status HloCostAnalysis::HandleReverse(const HloInstruction*) { 188 return Status::OK(); 189 } 190 191 Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { 192 current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2; 193 return Status::OK(); 194 } 195 196 Status HloCostAnalysis::HandleDynamicSlice( 197 const HloInstruction* dynamic_slice) { 198 current_properties_[kBytesAccessedKey] = 199 GetShapeSize(dynamic_slice->shape()) * 2; 200 return Status::OK(); 201 } 202 203 Status HloCostAnalysis::HandleDynamicUpdateSlice( 204 const HloInstruction* dynamic_update_slice) { 205 current_properties_[kBytesAccessedKey] = 206 GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2; 207 return Status::OK(); 208 } 209 210 Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) { 211 // The tuple instruction only gathers pointers from inputs (it doesn't iterate 212 // through them). The memory touched is then only the size of the output 213 // index table of the tuple. 214 215 current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape()); 216 return Status::OK(); 217 } 218 219 Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) { 220 return Status::OK(); 221 } 222 223 Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) { 224 return HandleElementwiseOp(convert); 225 } 226 227 Status HloCostAnalysis::HandleCopy(const HloInstruction*) { 228 return Status::OK(); 229 } 230 231 Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { 232 // Domain does not have any computation or data transfer. 233 current_should_compute_bottleneck_time_ = false; 234 current_properties_[kBytesAccessedKey] = 0; 235 current_properties_[kOptimalSecondsKey] = 0; 236 return Status::OK(); 237 } 238 239 Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { 240 const Shape& lhs_shape = dot->operand(0)->shape(); 241 const Shape& dot_shape = dot->shape(); 242 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); 243 // Count of elements along the reduction dimension (last dimension for the 244 // rhs). 245 int64 reduction_width = 1; 246 for (auto dim : dnums.lhs_contracting_dimensions()) { 247 reduction_width *= lhs_shape.dimensions(dim); 248 } 249 // Each output elment requires reduction_width FMA operations. 250 current_properties_[kFlopsKey] = 251 kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width; 252 return Status::OK(); 253 } 254 255 Status HloCostAnalysis::HandleInfeed(const HloInstruction*) { 256 return Status::OK(); 257 } 258 259 Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { 260 return Status::OK(); 261 } 262 263 Status HloCostAnalysis::HandleMap(const HloInstruction* map) { 264 // Compute properties of the mapped function. 265 TF_ASSIGN_OR_RETURN(const Properties sub_properties, 266 ProcessNestedSubcomputation(map->to_apply())); 267 268 // Compute the cost of all elements for this Map operation. 269 const int64 element_count = ShapeUtil::ElementsIn(map->shape()); 270 for (const auto& property : sub_properties) { 271 if (property.first != kBytesAccessedKey) { 272 current_properties_[property.first] = property.second * element_count; 273 } 274 } 275 return Status::OK(); 276 } 277 278 Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { 279 HloComputation* function = reduce->to_apply(); 280 // Compute the cost of the user function. 281 TF_ASSIGN_OR_RETURN(const Properties sub_properties, 282 ProcessNestedSubcomputation(function)); 283 284 // Compute the cost of all elements for this Reduce operation. 285 // This counts the number of times the reduction function is applied, so it 286 // does not need to be multiplied by the number of input tensors - that's 287 // already "priced in" by the sub-computation doing more work. 288 auto arg = reduce->operand(0); 289 auto output_shape = reduce->shape().IsArray() 290 ? reduce->shape() 291 : reduce->shape().tuple_shapes(0); 292 int64 reduction_count = 293 ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape); 294 for (const auto& property : sub_properties) { 295 if (property.first != kBytesAccessedKey) { 296 current_properties_[property.first] = property.second * reduction_count; 297 } 298 } 299 return Status::OK(); 300 } 301 302 Status HloCostAnalysis::HandleReduceWindow( 303 const HloInstruction* reduce_window) { 304 const Window& window = reduce_window->window(); 305 auto function = reduce_window->to_apply(); 306 // Compute the properties of the reduction function. 307 TF_ASSIGN_OR_RETURN(const Properties sub_properties, 308 ProcessNestedSubcomputation(function)); 309 310 // Compute the cost of all elements for this ReduceWindow operation. For each 311 // output element there are window_size - 1 reductions to perform. 312 int64 window_element_count = 1; 313 for (const auto& dimension : window.dimensions()) { 314 window_element_count *= dimension.size(); 315 } 316 const int64 output_element_count = 317 ShapeUtil::ElementsIn(reduce_window->shape()); 318 const int64 reduction_count = 319 (window_element_count - 1) * output_element_count; 320 for (const auto& property : sub_properties) { 321 if (property.first != kBytesAccessedKey) { 322 current_properties_[property.first] = property.second * reduction_count; 323 } 324 } 325 return Status::OK(); 326 } 327 328 Status HloCostAnalysis::HandleSelectAndScatter( 329 const HloInstruction* instruction) { 330 // Compute the properties of the select and scatter function. 331 // Compute the properties of the reduction function. 332 TF_ASSIGN_OR_RETURN(const Properties select_properties, 333 ProcessNestedSubcomputation(instruction->select())); 334 TF_ASSIGN_OR_RETURN(const Properties scatter_properties, 335 ProcessNestedSubcomputation(instruction->scatter())); 336 337 // Compute the cost of all elements for this operation. For each scatter 338 // source element there are window_size - 1 select computations to perform and 339 // 1 scatter computation to perform. 340 const auto source = instruction->operand(1); 341 const auto source_element_count = ShapeUtil::ElementsIn(source->shape()); 342 int64 window_element_count = 1; 343 for (const auto& dimension : instruction->window().dimensions()) { 344 window_element_count *= dimension.size(); 345 } 346 const int64 select_count = source_element_count * (window_element_count - 1); 347 for (const auto& property : select_properties) { 348 if (property.first != kBytesAccessedKey) { 349 current_properties_[property.first] += property.second * select_count; 350 } 351 } 352 for (const auto& property : scatter_properties) { 353 if (property.first != kBytesAccessedKey) { 354 current_properties_[property.first] += 355 property.second * source_element_count; 356 } 357 } 358 return Status::OK(); 359 } 360 361 Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { 362 // A bitcast does no computation and touches no memory. 363 current_properties_[kBytesAccessedKey] = 0; 364 current_properties_[kOptimalSecondsKey] = 0; 365 return Status::OK(); 366 } 367 368 Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) { 369 return Status::OK(); 370 } 371 372 Status HloCostAnalysis::HandlePad(const HloInstruction*) { 373 return Status::OK(); 374 } 375 376 Status HloCostAnalysis::HandleSend(const HloInstruction*) { 377 return Status::OK(); 378 } 379 380 Status HloCostAnalysis::HandleSendDone(const HloInstruction*) { 381 return Status::OK(); 382 } 383 384 Status HloCostAnalysis::HandleRecv(const HloInstruction*) { 385 return Status::OK(); 386 } 387 388 Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) { 389 return Status::OK(); 390 } 391 392 Status HloCostAnalysis::HandleReshape(const HloInstruction*) { 393 return Status::OK(); 394 } 395 396 Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) { 397 // TODO(b/62294698): Implement cost analysis for batch-norm-training. 398 return Status::OK(); 399 } 400 401 Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) { 402 // TODO(b/62294698): Implement cost analysis for batch-norm-inference. 403 return Status::OK(); 404 } 405 406 Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) { 407 // TODO(b/62294698): Implement cost analysis for batch-norm-grad. 408 return Status::OK(); 409 } 410 411 Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { 412 return Status::OK(); 413 } 414 415 Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { 416 // This instruction is used to enforce ordering at compile time. No code is 417 // emitted. 418 current_should_compute_bottleneck_time_ = false; 419 current_properties_[kBytesAccessedKey] = 0; 420 current_properties_[kOptimalSecondsKey] = 0; 421 return Status::OK(); 422 } 423 424 Status HloCostAnalysis::HandleAddDependency( 425 const HloInstruction* add_dependency) { 426 // This instruction is used to enforce ordering at compile time. No code is 427 // emitted. 428 current_should_compute_bottleneck_time_ = false; 429 current_properties_[kBytesAccessedKey] = 0; 430 current_properties_[kOptimalSecondsKey] = 0; 431 return Status::OK(); 432 } 433 434 Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { 435 auto lhs = convolution->operand(0); 436 auto rhs = convolution->operand(1); 437 Window window = convolution->window(); 438 const auto& result_shape = convolution->shape(); 439 const Shape& lhs_shape = lhs->shape(); 440 const Shape& rhs_shape = rhs->shape(); 441 442 const auto& dnums = convolution->convolution_dimension_numbers(); 443 444 const int64 input_batch_dim = dnums.input_batch_dimension(); 445 const int64 input_feature_dim = dnums.input_feature_dimension(); 446 const int64 output_feature_dim = dnums.output_feature_dimension(); 447 const int64 input_feature = 448 ShapeUtil::GetDimension(lhs_shape, input_feature_dim); 449 const int64 output_feature = 450 ShapeUtil::GetDimension(result_shape, output_feature_dim); 451 const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim); 452 453 DimensionVector kernel_limits; 454 DimensionVector output_limits; 455 DimensionVector input_limits; 456 if (window.dimensions().empty()) { 457 window = window_util::MakeWindow({1}); 458 kernel_limits.push_back(1); 459 output_limits.push_back(1); 460 input_limits.push_back(1); 461 } else { 462 for (int64 spatial_dimension = 0; 463 spatial_dimension < window.dimensions_size(); ++spatial_dimension) { 464 // Spatial dimension number for kernel (rhs). 465 const int64 kernel_spatial_dim = 466 dnums.kernel_spatial_dimensions(spatial_dimension); 467 const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim); 468 kernel_limits.push_back(kernel_limit); 469 470 // Spatial dimension number for output. 471 const int64 output_spatial_dim = 472 dnums.output_spatial_dimensions(spatial_dimension); 473 const int64 output_limit = result_shape.dimensions(output_spatial_dim); 474 output_limits.push_back(output_limit); 475 476 // Spatial dimension number for input (lhs). 477 const int64 input_spatial_dim = 478 dnums.input_spatial_dimensions(spatial_dimension); 479 const int64 input_limit = lhs_shape.dimensions(input_spatial_dim); 480 input_limits.push_back(input_limit); 481 } 482 } 483 484 DimensionVector valid_position_counts; 485 486 // Loop over each spatial dimension. 487 for (int64 spatial_dimension = 0; 488 spatial_dimension < window.dimensions_size(); ++spatial_dimension) { 489 int64 valid_position_count = 0; 490 // Loop over each point in the kernel. 491 for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension]; 492 ++kernel_idx) { 493 // Loop over each point in the output. 494 for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension]; 495 ++output_idx) { 496 // Calculate lhs (input) index without taking base dilation into 497 // account. 498 const auto& window_dim = window.dimensions(spatial_dimension); 499 const int64 undilated_index = output_idx * window_dim.stride() - 500 window_dim.padding_low() + 501 kernel_idx * window_dim.window_dilation(); 502 503 // Calculate the actual lhs (input) index after dilation. Avoid the 504 // division as an optimization. 505 const int64 lhs_spatial_index = 506 window_dim.base_dilation() > 1 507 ? undilated_index / window_dim.base_dilation() 508 : undilated_index; 509 510 // Skip if the lhs (input) index is to be dilated. 511 if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) { 512 continue; 513 } 514 515 // Skip if input index is not in bound. 516 if (lhs_spatial_index < 0 || 517 lhs_spatial_index >= input_limits[spatial_dimension]) { 518 continue; 519 } 520 521 valid_position_count += 1; 522 } 523 } 524 valid_position_counts.push_back(valid_position_count); 525 } 526 527 const int64 fma_count = (input_feature / convolution->feature_group_count()) * 528 output_feature * 529 (batch / convolution->batch_group_count()) * 530 Product(valid_position_counts); 531 current_properties_[kFlopsKey] = fma_count * kFmaFlops; 532 return Status::OK(); 533 } 534 535 Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { 536 auto real_shape = 537 fft->operand(0)->shape().IsTuple() 538 ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0) 539 : fft->operand(0)->shape(); 540 constexpr int kFmaPerComplexMul = 4; 541 int64 log_factors = 1; 542 for (int64 dim : fft->fft_length()) { 543 log_factors *= tensorflow::Log2Floor(dim); 544 } 545 current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors * 546 ShapeUtil::ElementsIn(real_shape); 547 return Status::OK(); 548 } 549 550 Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) { 551 float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f; 552 bytes_accessed += GetShapeSize(hlo->operand(1)->shape()); 553 current_properties_[kBytesAccessedKey] = bytes_accessed; 554 555 const Shape& a_shape = hlo->operand(0)->shape(); 556 const Shape& b_shape = hlo->operand(1)->shape(); 557 // Estimate as batch * mn^2 / 2 flops. 558 int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); 559 elems *= ShapeUtil::ElementsIn(b_shape); 560 current_properties_[kFlopsKey] = kFmaFlops * elems; 561 return Status::OK(); 562 } 563 564 Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) { 565 float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f; 566 current_properties_[kBytesAccessedKey] = bytes_accessed; 567 568 const Shape& a_shape = hlo->operand(0)->shape(); 569 // Estimate as batch * n^3 / 3 flops. 570 int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); 571 elems *= ShapeUtil::ElementsIn(a_shape); 572 current_properties_[kFlopsKey] = elems / 3; 573 return Status::OK(); 574 } 575 576 Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { 577 // We assume 2 replicas, so that each output element is the sum of two input 578 // elements. 579 // 580 // TODO(b/33004697): Compute correct cost here, taking the actual number of 581 // replicas into account. 582 double flops = 0.0; 583 ShapeUtil::ForEachSubshape(crs->shape(), 584 [&](const Shape& subshape, const ShapeIndex&) { 585 if (subshape.IsArray()) { 586 flops += ShapeUtil::ElementsIn(subshape); 587 } 588 }); 589 current_properties_[kFlopsKey] = flops; 590 return Status::OK(); 591 } 592 593 Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { 594 return Status::OK(); 595 } 596 597 Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { 598 return Status::OK(); 599 } 600 601 Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) { 602 return Status::OK(); 603 } 604 605 Status HloCostAnalysis::HandleRng(const HloInstruction* random) { 606 // TODO(b/26346211): Implement better estimates for the RNG cost, since the 607 // cost changes with the implementation and the distribution. For now, assume 608 // the cost of each RNG is same as a transcendental operation. 609 current_properties_[kTranscendentalsKey] = 610 ShapeUtil::ElementsIn(random->shape()); 611 return Status::OK(); 612 } 613 614 Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { 615 TF_ASSIGN_OR_RETURN( 616 current_properties_, 617 ProcessNestedSubcomputation(fusion->fused_instructions_computation())); 618 619 // Fusion nodes that produce a tuple also produce the entries in the tuple. 620 // Ignore the memory accessed inside fused ops, since fusion is supposed to 621 // prevent intermediate data from touching slow memory. 622 current_properties_[kBytesAccessedKey] = 0; 623 ShapeUtil::ForEachSubshape( 624 fusion->shape(), 625 [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { 626 current_properties_[kBytesAccessedKey] += GetShapeSize(subshape); 627 }); 628 629 for (const HloInstruction* operand : fusion->operands()) { 630 current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape()); 631 } 632 633 return Status::OK(); 634 } 635 636 Status HloCostAnalysis::HandleCall(const HloInstruction* call) { 637 TF_ASSIGN_OR_RETURN(current_properties_, 638 ProcessUnnestedSubcomputation(call->to_apply())); 639 current_should_compute_bottleneck_time_ = false; 640 return Status::OK(); 641 } 642 643 Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { 644 // Mark applicable fields as "unknown", since we don't know what CustomCall 645 // does. This is better than returning an error, which would stop iteration, 646 // and therefore would prevent us from getting *any* stats for a computation 647 // which contains a CustomCall. 648 current_properties_[kOptimalSecondsKey] = -1; 649 current_properties_[kBytesAccessedKey] = -1; 650 current_properties_[kFlopsKey] = -1; 651 current_should_compute_bottleneck_time_ = false; 652 return Status::OK(); 653 } 654 655 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) { 656 // This assumes a comparison based N*log(N) algorithm. As for all ops, the 657 // actual properties of the op depend on the backend implementation. 658 int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape()); 659 current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements); 660 return Status::OK(); 661 } 662 663 Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { 664 // Since the number of iterations of the while node will not always be 665 // something that we can statically analyze, we cannot precisely compute the 666 // cost of a while node. For now compute the cost of a single iteration. 667 TF_ASSIGN_OR_RETURN(const Properties body_properties, 668 ProcessUnnestedSubcomputation(xla_while->while_body())); 669 670 TF_ASSIGN_OR_RETURN( 671 const Properties condition_properties, 672 ProcessUnnestedSubcomputation(xla_while->while_condition())); 673 674 current_properties_.clear(); 675 for (const auto& property : body_properties) { 676 current_properties_[property.first] += property.second; 677 } 678 for (const auto& property : condition_properties) { 679 current_properties_[property.first] += property.second; 680 } 681 current_should_compute_bottleneck_time_ = false; 682 683 return Status::OK(); 684 } 685 686 Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { 687 // Compute the cost of the branch computations and take the maximum from those 688 // for each property. 689 TF_ASSIGN_OR_RETURN( 690 const Properties branch0_computation_properties, 691 ProcessUnnestedSubcomputation(conditional->branch_computation(0))); 692 current_properties_ = branch0_computation_properties; 693 for (int j = 1; j < conditional->branch_count(); ++j) { 694 TF_ASSIGN_OR_RETURN( 695 const Properties branch_computation_properties, 696 ProcessUnnestedSubcomputation(conditional->branch_computation(j))); 697 for (const auto& property : branch_computation_properties) { 698 if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, 699 property)) { 700 auto& current_property = current_properties_[property.first]; 701 current_property = std::max(current_property, property.second); 702 } 703 } 704 } 705 current_should_compute_bottleneck_time_ = false; 706 707 return Status::OK(); 708 } 709 710 Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { 711 // Gather doesn't read the whole input buffer, it's equivalent to a copy the 712 // size of the output shape and a read of the gather indices. 713 current_properties_[kBytesAccessedKey] = 714 GetShapeSize(gather->shape()) * 2 + 715 GetShapeSize(gather->operand(1)->shape()); 716 // Gather does not issue any flops. 717 return Status::OK(); 718 } 719 720 Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { 721 current_properties_[kBytesAccessedKey] = 722 GetShapeSize(scatter->operand(2)->shape()) * 2 + 723 GetShapeSize(scatter->operand(1)->shape()); 724 const int64 element_count = 725 ShapeUtil::ElementsIn(scatter->operand(2)->shape()); 726 TF_ASSIGN_OR_RETURN(const Properties sub_properties, 727 ProcessNestedSubcomputation(scatter->to_apply())); 728 for (const auto& property : sub_properties) { 729 if (property.first != kBytesAccessedKey) { 730 current_properties_[property.first] = property.second * element_count; 731 } 732 } 733 return Status::OK(); 734 } 735 736 Status HloCostAnalysis::HandleGetDimensionSize( 737 const HloInstruction* /*get_size*/) { 738 return Status::OK(); 739 } 740 741 Status HloCostAnalysis::FinishVisit(const HloInstruction*) { 742 return Status::OK(); 743 } 744 745 float HloCostAnalysis::flop_count() const { 746 return GetProperty(kFlopsKey, properties_sum_); 747 } 748 749 float HloCostAnalysis::transcendental_count() const { 750 return GetProperty(kTranscendentalsKey, properties_sum_); 751 } 752 753 float HloCostAnalysis::bytes_accessed() const { 754 return GetProperty(kBytesAccessedKey, properties_sum_); 755 } 756 757 float HloCostAnalysis::optimal_seconds() const { 758 return GetProperty(kOptimalSecondsKey, properties_sum_); 759 } 760 761 int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { 762 return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_); 763 } 764 765 int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const { 766 return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_); 767 } 768 769 int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { 770 return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); 771 } 772 773 float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { 774 return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); 775 } 776 777 StatusOr<HloCostAnalysis::Properties> 778 HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) { 779 HloCostAnalysis visitor(shape_size_, per_second_rates_); 780 TF_RETURN_IF_ERROR(computation->Accept(&visitor)); 781 return visitor.properties(); 782 } 783 784 StatusOr<HloCostAnalysis::Properties> 785 HloCostAnalysis::ProcessUnnestedSubcomputation(HloComputation* computation) { 786 HloCostAnalysis visitor(shape_size_, per_second_rates_); 787 TF_RETURN_IF_ERROR(computation->Accept(&visitor)); 788 hlo_properties_.insert(visitor.hlo_properties_.begin(), 789 visitor.hlo_properties_.end()); 790 return visitor.properties(); 791 } 792 793 } // namespace xla 794