1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" 17 18 #include "absl/algorithm/container.h" 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/flat_hash_set.h" 21 #include "absl/container/inlined_vector.h" 22 #include "absl/strings/str_cat.h" 23 #include "absl/strings/str_join.h" 24 #include "absl/types/optional.h" 25 #include "tensorflow/compiler/xla/map_util.h" 26 #include "tensorflow/compiler/xla/service/hlo_evaluator.h" 27 #include "tensorflow/compiler/xla/util.h" 28 29 namespace xla { 30 31 namespace { 32 using Analysis = IndexedArrayAnalysis; 33 using UnknownArray = Analysis::UnknownArray; 34 using ConstantArray = Analysis::ConstantArray; 35 using ReshapedArray = Analysis::ReshapedArray; 36 using ScalarIndexedArray = Analysis::ScalarIndexedArray; 37 using absl::StrJoin; 38 } // namespace 39 40 string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { 41 switch (root->kind()) { 42 case Array::kUnknown: { 43 auto* unknown_tensor = root->as<UnknownArray>(); 44 return absl::StrCat("%", unknown_tensor->instruction().name()); 45 } 46 47 case Array::kConstant: { 48 if (print_constants) { 49 string contents = root->as<ConstantArray>()->literal()->ToString(); 50 return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), 51 " ", contents, ")"); 52 } 53 return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), 54 ")"); 55 } 56 57 case Array::kReshaped: { 58 ReshapedArray* reshaped_array = root->as<ReshapedArray>(); 59 return absl::StrCat( 60 "(reshape ", ToString(reshaped_array->operand(), print_constants), 61 " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); 62 } 63 64 case Array::kScalarIndexedConstant: 65 case Array::kScalarIndexed: { 66 auto* indexed_array = root->as<ScalarIndexedArray>(); 67 string name = root->kind() == Array::kScalarIndexedConstant 68 ? "scalar-indexed-const" 69 : "scalar-indexed"; 70 return absl::StrCat( 71 "(", name, " ", ToString(indexed_array->source(), print_constants), 72 " ", ToString(indexed_array->indices(), print_constants), " ", 73 indexed_array->source_dim(), "->[", 74 StrJoin(indexed_array->output_dims(), ","), "])"); 75 } 76 } 77 } 78 79 StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor( 80 const HloInstruction* instr) { 81 auto it = cache_.find(instr); 82 if (it != cache_.end()) { 83 return it->second; 84 } 85 86 TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr)); 87 return FindOrDie(cache_, instr); 88 } 89 90 Status IndexedArrayAnalysis::TraverseAndPopulateCache( 91 const HloInstruction* root) { 92 // Depth first search over the DAG, invoking ComputeArrayFor in post order. 93 // The HLO instructions already in the cache are considered leaves. 94 95 absl::InlinedVector<const HloInstruction*, 4> stack; 96 97 enum DfsState { kDiscovered, kVisited }; 98 absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map; 99 100 stack.push_back(root); 101 InsertOrDie(&dfs_state_map, root, kDiscovered); 102 103 do { 104 const HloInstruction* instr = stack.back(); 105 if (cache_.contains(instr)) { 106 stack.pop_back(); 107 continue; 108 } 109 110 switch (FindOrDie(dfs_state_map, instr)) { 111 case kDiscovered: { 112 for (const HloInstruction* operand : instr->operands()) { 113 if (!cache_.contains(operand)) { 114 stack.push_back(operand); 115 CHECK(!dfs_state_map.contains(operand) || 116 dfs_state_map[operand] == kDiscovered); 117 dfs_state_map[operand] = kDiscovered; 118 } 119 } 120 dfs_state_map[instr] = kVisited; 121 break; 122 } 123 124 case kVisited: 125 stack.pop_back(); 126 TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr)); 127 InsertOrDie(&cache_, instr, array); 128 break; 129 } 130 } while (!stack.empty()); 131 132 return Status::OK(); 133 } 134 135 StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor( 136 const HloInstruction* instr) { 137 Array* computed_array; 138 if (instr->IsElementwise() && instr->operand_count() == 1) { 139 TF_ASSIGN_OR_RETURN( 140 computed_array, 141 ComputeArrayForElementwiseUnaryOp( 142 instr->opcode(), FindOrDie(cache_, instr->operand(0)))); 143 } else if (instr->IsElementwise() && instr->operand_count() == 2) { 144 TF_ASSIGN_OR_RETURN( 145 computed_array, 146 ComputeArrayForElementwiseBinaryOp( 147 instr->opcode(), FindOrDie(cache_, instr->operand(0)), 148 FindOrDie(cache_, instr->operand(1)))); 149 } else if (instr->opcode() == HloOpcode::kConstant) { 150 TF_ASSIGN_OR_RETURN(computed_array, 151 ComputeArrayForConstant(instr->literal())); 152 } else if (instr->opcode() == HloOpcode::kGather) { 153 TF_ASSIGN_OR_RETURN( 154 computed_array, 155 ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), 156 instr->gather_slice_sizes(), 157 FindOrDie(cache_, instr->operand(0)), 158 FindOrDie(cache_, instr->operand(1)))); 159 } else if (instr->opcode() == HloOpcode::kReshape) { 160 TF_ASSIGN_OR_RETURN( 161 computed_array, 162 ComputeArrayForReshape(instr->shape(), 163 FindOrDie(cache_, instr->operand(0)))); 164 } else if (instr->opcode() == HloOpcode::kDot) { 165 TF_ASSIGN_OR_RETURN( 166 computed_array, 167 ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), 168 instr->precision_config(), 169 FindOrDie(cache_, instr->operand(0)), 170 FindOrDie(cache_, instr->operand(1)))); 171 } else { 172 computed_array = nullptr; 173 } 174 175 if (!computed_array) { 176 computed_array = Construct<UnknownArray>(instr); 177 } 178 179 return computed_array; 180 } 181 182 StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant( 183 const Literal& literal) { 184 return Construct<ConstantArray>(&literal); 185 } 186 187 StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather( 188 ScalarIndexedArray* source, Array* indices, int64 source_dim, 189 absl::Span<const int64> output_dims, Shape shape) { 190 // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). 191 // `source` is the inner Gather(A, X). 192 193 Array* a = source->source(); 194 Array* x = source->indices(); 195 Array* y = indices; 196 197 // This bit is slightly tricky, so we do a naive "simulation" of the two 198 // consecutive gather operations to infer what the composed gather should look 199 // like. 200 201 enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; 202 203 std::vector<IndexComponent> simulated_index(a->shape().dimensions_size(), 204 IndexComponent::Ungathered); 205 206 // Simulate the first gather. 207 EraseAt(&simulated_index, source->source_dim()); 208 for (int64 gather_dim : source->output_dims()) { 209 simulated_index.insert(simulated_index.begin() + gather_dim, 210 IndexComponent::GatheredFirst); 211 } 212 213 // Simulate the second gather. 214 EraseAt(&simulated_index, source_dim); 215 for (int64 output_dim : output_dims) { 216 simulated_index.insert(simulated_index.begin() + output_dim, 217 IndexComponent::GatheredSecond); 218 } 219 220 int64 source_dim_for_index_array = 221 FindIndex(source->output_dims(), source_dim); 222 CHECK_NE(source_dim_for_index_array, source->output_dims().size()); 223 224 std::vector<int64> output_dims_for_index_array; 225 int64 gathered_index_components_seen = 0; 226 for (IndexComponent simulation_dim : simulated_index) { 227 if (simulation_dim == IndexComponent::GatheredSecond) { 228 output_dims_for_index_array.push_back(gathered_index_components_seen); 229 } 230 if (simulation_dim != IndexComponent::Ungathered) { 231 gathered_index_components_seen++; 232 } 233 } 234 235 std::vector<int64> dim_sizes_for_composed_index; 236 std::vector<int64> output_dims_for_new_gather; 237 for (int64 i = 0, e = simulated_index.size(); i < e; i++) { 238 if (simulated_index[i] != IndexComponent::Ungathered) { 239 dim_sizes_for_composed_index.push_back(shape.dimensions(i)); 240 output_dims_for_new_gather.push_back(i); 241 } 242 } 243 244 Array* inner_indices = ConstructScalarIndexedArray( 245 x, y, source_dim_for_index_array, output_dims_for_index_array, 246 ShapeUtil::MakeShape(x->shape().element_type(), 247 dim_sizes_for_composed_index)); 248 return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(), 249 output_dims_for_new_gather, 250 std::move(shape)); 251 } 252 253 StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather( 254 const Shape& shape, const GatherDimensionNumbers& dim_numbers, 255 absl::Span<const int64> slice_sizes, Array* source, Array* indices) { 256 if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { 257 VLOG(3) << "ComputeArrayForGather: indices are not scalar"; 258 return nullptr; 259 } 260 261 CHECK_EQ(dim_numbers.start_index_map_size(), 1); 262 263 // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here, 264 // should it become relevant. 265 266 if (dim_numbers.collapsed_slice_dims_size() != 1 || 267 dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) { 268 VLOG(3) << "ComputeArrayForGather: gather operations must elide " 269 "start_index_map[0] and " 270 "start_index_map[0] only"; 271 return nullptr; 272 } 273 274 // ScalarIndexedArray cannot represent gathers that "slice" along some 275 // dimensions -- for instance it cannot represent a gather that picks 5 [2,3] 276 // arrays from an array of size [7,4,6]. We check that condition down below: 277 278 for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { 279 if (i != dim_numbers.collapsed_slice_dims(0) && 280 source->shape().dimensions(i) != slice_sizes[i]) { 281 VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i 282 << "] != source->shape().dimensions(" << i << ") -- " 283 << source->shape().dimensions(i) << " vs. " << slice_sizes[i] 284 << " with dim_numbers.collapsed_slice_dims(0) = " 285 << dim_numbers.collapsed_slice_dims(0); 286 return nullptr; 287 } 288 } 289 290 int64 source_dim = dim_numbers.start_index_map(0); 291 std::vector<int64> output_dims; 292 for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { 293 if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { 294 output_dims.push_back(i); 295 } 296 } 297 298 if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) { 299 if (absl::c_linear_search(indexed->output_dims(), source_dim)) { 300 return FoldGatherOfGather(indexed, indices, source_dim, output_dims, 301 shape); 302 } 303 } else if (auto* constant = dynamic_cast<ConstantArray*>(source)) { 304 return Construct<ScalarIndexedConstantArray>(constant, indices, source_dim, 305 output_dims, shape); 306 } 307 308 return Construct<ScalarIndexedArray>(source, indices, source_dim, output_dims, 309 shape); 310 } 311 312 namespace { 313 // Returns an index into `values` such that the product of the range 314 // [values.begin()+index, values.end()) is equal to `product`. If there is no 315 // such index, return -1. All integers in `values` must be positive. 316 int64 FindSuffixWithProduct(absl::Span<const int64> values, int64 product) { 317 DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); 318 319 int64 current_product = 1; 320 int64 i; 321 for (i = values.size() - 1; i >= 0 && product > current_product; --i) { 322 current_product *= values[i]; 323 } 324 325 if (product == current_product) { 326 return i + 1; 327 } 328 329 return -1; 330 } 331 332 struct ReshapePassthroughDimPair { 333 int64 result_dim; 334 int64 operand_dim; 335 }; 336 337 // Returns a set of dimension pairs such for all (result_dim, operand_dim) in 338 // the set: 339 // 340 // output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim] 341 // 342 // The returned vector of pairs is sorted in both the result_dim and the 343 // operand_dim components. 344 std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( 345 absl::Span<const int64> operand_shape, 346 absl::Span<const int64> result_shape) { 347 // A reshape can be seen as an index mapping from output index to input index: 348 // 349 // (i_0, ..., i_n) = f(o_0, ..., o_m) 350 // 351 // This function returns the pairs (j, k) for which the following invariant 352 // holds for all indices in the shape: 353 // 354 // o_j == i_k 355 // 356 // And this occurs when: 357 // 358 // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m 359 // 360 // (where O_x are the sizes of the output shape and I_x are the sizes of the 361 // input shape) and the size of the dimension j of the result is the same as 362 // the size of dimension k in the operand. 363 // 364 // These conditions are sufficient because the Reshape HLO is spec'ed such 365 // that the rightmost dimensions are always minor in the flattening and refine 366 // operation. 367 368 std::vector<ReshapePassthroughDimPair> result; 369 int64 result_subarray_size = 1; 370 for (int64 result_dim = result_shape.size() - 1; result_dim >= 0; 371 --result_dim) { 372 int64 candidate_operand_dim = 373 FindSuffixWithProduct(operand_shape, result_subarray_size); 374 375 // result_subarray_size does not include the elements in the current 376 // `result_dim` dimension (we multiply in result_shape[result_dim] at the 377 // end of loop body) so candidate_operand_dim can never be zero. 378 CHECK_NE(candidate_operand_dim, 0) 379 << "result_dim = " << result_dim 380 << ", result_subarray_size = " << result_subarray_size 381 << ", result_shape = [" << StrJoin(result_shape, ",") << "]" 382 << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; 383 384 if (candidate_operand_dim != -1 && 385 result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { 386 result.push_back({/*result_dim=*/result_dim, 387 /*operand_dim=*/candidate_operand_dim - 1}); 388 } 389 result_subarray_size *= result_shape[result_dim]; 390 } 391 392 absl::c_reverse(result); 393 394 if (VLOG_IS_ON(3)) { 395 std::vector<string> result_strings; 396 absl::c_transform(result, std::back_inserter(result_strings), 397 [](ReshapePassthroughDimPair value) { 398 return absl::StrCat(value.result_dim, "->", 399 value.operand_dim); 400 }); 401 VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" 402 << StrJoin(result_shape, ",") << "] passthrough indices are [" 403 << StrJoin(result_strings, ",") 404 << "] (legend: `result`->`operand`)"; 405 } 406 407 DCHECK(absl::c_is_sorted( 408 result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { 409 return lhs.result_dim < rhs.result_dim; 410 })); 411 412 DCHECK(absl::c_is_sorted( 413 result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { 414 return lhs.operand_dim < rhs.operand_dim; 415 })); 416 417 return result; 418 } 419 420 // Return true if `dim` is stated as an passthrough operand dim in 421 // `passthrough_dims`. 422 bool IsReshapePassthroughOperandDim( 423 absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64 dim) { 424 return absl::c_any_of(passthrough_dims, 425 [&](ReshapePassthroughDimPair passthrough_dim_pair) { 426 return passthrough_dim_pair.operand_dim == dim; 427 }); 428 } 429 430 // Maps `operand_dim` which must be an passthrough operand dimension to its 431 // corresponding passthrough result dimension based on `passthrough_dims`. 432 int64 MapPassthroughOperandDimToResultDim( 433 absl::Span<const ReshapePassthroughDimPair> passthrough_dims, 434 int64 operand_dim) { 435 auto it = absl::c_find_if( 436 passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { 437 return passthrough_dim_pair.operand_dim == operand_dim; 438 }); 439 CHECK(it != passthrough_dims.end()); 440 return it->result_dim; 441 } 442 443 int64 FindSourcePositionForPassthroughResultDim( 444 absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape, 445 int64 source_passthrough_dim) { 446 VLOG(3) << "FindSourcePositionForPassthroughResultDim([" 447 << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") 448 << "], " << source_passthrough_dim << ")"; 449 450 int64 indexed_source_subarray_size = 451 std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, 452 operand_shape.end(), 1LL, std::multiplies<int64>()); 453 454 return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); 455 } 456 457 Shape StripDegenerateDimensions(const Shape& shape) { 458 DimensionVector new_dims; 459 absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims), 460 [](int64 dim) { return dim != 1; }); 461 return ShapeUtil::MakeShape(shape.element_type(), new_dims); 462 } 463 }; // namespace 464 465 StatusOr<ScalarIndexedArray*> 466 IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( 467 ScalarIndexedArray* operand) { 468 const Shape& shape = operand->shape(); 469 if (!ShapeUtil::HasDegenerateDimensions(shape)) { 470 return operand; 471 } 472 473 // We only need to reshape out the degenerate dims from the indices and the 474 // source (except the source dim). 475 476 const Shape& source_shape = operand->source()->shape(); 477 DimensionVector new_source_shape_dims; 478 for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) { 479 if (i == operand->source_dim() || source_shape.dimensions(i) != 1) { 480 new_source_shape_dims.push_back(source_shape.dimensions(i)); 481 } 482 } 483 484 Shape new_source_shape = 485 ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims); 486 Shape new_indices_shape = 487 StripDegenerateDimensions(operand->indices()->shape()); 488 489 TF_ASSIGN_OR_RETURN( 490 Array* const new_source, 491 ComputeArrayForReshape(new_source_shape, operand->source())); 492 TF_ASSIGN_OR_RETURN( 493 Array* const new_indices, 494 ComputeArrayForReshape(new_indices_shape, operand->indices())); 495 496 // Build the new output dims while keeping track of the degenerate dims that 497 // will no longer be present. 498 DimensionVector new_output_dims; 499 int64 degenerate_dims_seen = 0; 500 for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { 501 if (shape.dimensions(i) == 1) { 502 degenerate_dims_seen++; 503 } else if (absl::c_linear_search(operand->output_dims(), i)) { 504 new_output_dims.push_back(i - degenerate_dims_seen); 505 } 506 } 507 508 // Similarly, build the new source dim while keeping track of the degenerate 509 // dims that will no longer be present. 510 int64 degenerate_dims_before_source_dim = 511 std::count(source_shape.dimensions().begin(), 512 source_shape.dimensions().begin() + operand->source_dim(), 1); 513 int64 new_source_dim = 514 operand->source_dim() - degenerate_dims_before_source_dim; 515 516 return ConstructScalarIndexedArray( 517 new_source, new_indices, new_source_dim, 518 InlinedVectorToVector(new_output_dims), 519 StripDegenerateDimensions(operand->shape())); 520 } 521 522 StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims( 523 ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims) { 524 if (degenerate_dims.empty()) { 525 return operand; 526 } 527 528 CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape())); 529 530 DimensionVector new_output_dims = [&]() { 531 // To make things easy we use a "scratch" buffer of bools where the i'th 532 // element is true iff the i'th component of the result index is an output 533 // index. 534 535 absl::InlinedVector<bool, 6> output_dims_bitvector( 536 operand->shape().dimensions_size()); 537 for (int64 output_dim : operand->output_dims()) { 538 output_dims_bitvector[output_dim] = true; 539 } 540 541 for (int64 degenerate_dim : degenerate_dims) { 542 InsertAt(&output_dims_bitvector, degenerate_dim, false); 543 } 544 545 DimensionVector result; 546 result.reserve(operand->output_dims().size()); 547 for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) { 548 if (output_dims_bitvector[i]) { 549 result.push_back(i); 550 } 551 } 552 553 return result; 554 }(); 555 556 DimensionVector new_result_shape_dims; 557 absl::c_copy(operand->shape().dimensions(), 558 std::back_inserter(new_result_shape_dims)); 559 for (int64 degenerate_dim : degenerate_dims) { 560 InsertAt(&new_result_shape_dims, degenerate_dim, 1); 561 } 562 563 DimensionVector new_source_shape_dims = new_result_shape_dims; 564 for (int64 output_dim : new_output_dims) { 565 EraseAt(&new_source_shape_dims, output_dim); 566 } 567 568 int64 new_source_dim = [&]() { 569 for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) { 570 int64 non_degenerate_dims_seen = 0; 571 if (non_degenerate_dims_seen == operand->source_dim()) { 572 return i; 573 } 574 if (new_source_shape_dims[new_source_dim] != 1) { 575 non_degenerate_dims_seen++; 576 } 577 } 578 LOG(FATAL) << "Did not find source dim in " << ToString(operand); 579 }(); 580 581 int64 source_dim_size = 582 operand->source()->shape().dimensions(operand->source_dim()); 583 InsertAt(&new_source_shape_dims, /*index=*/new_source_dim, 584 /*value=*/source_dim_size); 585 586 Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(), 587 new_source_shape_dims); 588 Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(), 589 new_result_shape_dims); 590 591 TF_ASSIGN_OR_RETURN( 592 Array* const new_source, 593 ComputeArrayForReshape(new_source_shape, operand->source())); 594 return ConstructScalarIndexedArray( 595 new_source, operand->indices(), new_source_dim, 596 InlinedVectorToVector(new_output_dims), new_result_shape); 597 } 598 599 StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldReshapeOfGather( 600 const Shape& shape, ScalarIndexedConstantArray* operand) { 601 VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")"; 602 603 // To make things easier on ourselves, instead of directly trying to fold the 604 // reshape of `operand` to `shape`, we call 605 // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and 606 // handle the degenerate dimensions here by inserting reshapes. 607 608 TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims, 609 ReshapeToRemoveDegenerateDims(operand)); 610 611 Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape); 612 TF_ASSIGN_OR_RETURN( 613 ScalarIndexedArray* const folded_reshape_without_degenerate_dims, 614 FoldReshapeOfGatherNoDegenerateDims( 615 output_shape_without_degenerate_dims, 616 operand_without_degenerate_dims->as<ScalarIndexedConstantArray>())); 617 618 if (folded_reshape_without_degenerate_dims == nullptr) { 619 return nullptr; 620 } 621 622 DimensionVector degenerate_result_dims; 623 for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { 624 if (shape.dimensions(i) == 1) { 625 degenerate_result_dims.push_back(i); 626 } 627 } 628 629 return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims, 630 degenerate_result_dims); 631 } 632 633 StatusOr<ScalarIndexedArray*> 634 IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( 635 const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) { 636 VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed) 637 << ")"; 638 CHECK(!ShapeUtil::HasDegenerateDimensions(shape)); 639 CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape())); 640 641 // Try to fold Reshape(ScalarIndexed(Const, Indices)) 642 // => ScalarIndexed(Const', Indices) 643 // 644 // We can view the reshape and the scalar-indexed operations as functions that 645 // map an output index (i.e. an index into the result) to an input index 646 // (i.e. an index into the operand). The key idea used here is that the 647 // output-to-input mapping for some reshape operations may "pass through" some 648 // output dimensions into the input space unchanged -- i.e. there may exist 649 // output dimension "O" and input dimension "I" such that OutputIndex[O] is 650 // always == InputIndexForReshape(OutputIndex)[I]. If these pass-through 651 // dimensions in the input space of the reshape happen to be include all the 652 // output dimensions for the scalar-indexed node then, roughly, the following 653 // holds: 654 // 655 // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx)) 656 // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs)) 657 // 658 // Where Ps are the set of the pass-through components of Idx that are 659 // also the output dims of the scalar-indexed node, and Qs are the rest. 660 // For brevity, we're playing fast and loose with the notation here -- we 661 // don't literally require Idx to be a concatenation of Ps and Qs, as 662 // suggested by the "++". 663 // 664 // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs)) 665 // 666 // Again, we're playing fast and loose with the notation around "++". 667 // Generally this ++ will be a different function that the ++ in the 668 // previous step. 669 // 670 // If the scalar-indexed node has a constant as the source then the 671 // SourceIndexOfReshape function can be "folded into" the constant itself by 672 // reshaping it, leaving us with: 673 // 674 // == SourceIndexOfScalarIndexed(Ps ++ Qs) 675 // == SourceIndexOfScalarIndexed(Idx) 676 // 677 // which is just a scalar-indexed node (with parameters different from the 678 // scalar-indexed node we started with) with a reshaped constant as the 679 // source. 680 // 681 // We can't fold SourceIndexOfReshape into the constant without introducing 682 // another precondition: since the new scalar-indexed node will have a 683 // reshaped (constant) array as its source it will, in general, have a 684 // different source dimension than the original scalar-indexed node. This 685 // source dimension will have to be a passthrough dimension of the 686 // SourceIndexOfReshape indexing function that is folded into the source. And 687 // such a dimension need not exist so this is a non-trivial precondition. 688 689 std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims = 690 ComputeReshapePassthroughDimPairs( 691 /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()), 692 /*result_shape=*/AsInt64Slice(shape.dimensions())); 693 694 auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { 695 return IsReshapePassthroughOperandDim(reshape_passthrough_dims, 696 operand_dim); 697 }; 698 699 if (!absl::c_all_of(scalar_indexed->output_dims(), 700 is_reshape_passthrough_operand_dim)) { 701 VLOG(3) << "Not all output dims are passthrough dims " 702 << ToString(scalar_indexed); 703 return nullptr; 704 } 705 706 // To compute the shape of the source for the new scalar-indexed node we're 707 // going to create, we first "undo" the scalar-indexed operation. 708 std::vector<int64> new_scalar_indexed_source_shape(shape.dimensions().begin(), 709 shape.dimensions().end()); 710 for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) { 711 int64 output_dim = scalar_indexed->output_dims()[i]; 712 int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim( 713 reshape_passthrough_dims, output_dim); 714 EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape); 715 } 716 717 // After this, we need to add in the dimension that will be the source 718 // dimension for the new scalar-indexed node. A scalar-indexed node "removes" 719 // the source dimensions and "adds" the output dimensions, so to get back to 720 // the shape for the *source* of the scalar-indexed node we need to remove the 721 // output dims (which we did above) and then add back the source dim (which we 722 // are about to do below): 723 724 const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape(); 725 726 int64 source_dim_for_new_scalar_indexed_node = 727 FindSourcePositionForPassthroughResultDim( 728 /*operand_shape=*/AsInt64Slice( 729 scalar_indexed_source_shape.dimensions()), 730 /*result_shape=*/new_scalar_indexed_source_shape, 731 scalar_indexed->source_dim()); 732 733 // We may not be able to find a source dim for the new scalar-indexed node. 734 // For instance consider: 735 // 736 // operand = s32[3,5,2] constant({...}) 737 // indices = s32[7] parameter(0) 738 // gather = s32[3,2,7] gather(operand, indices), 739 // offset_dims={0,1}, 740 // collapsed_slice_dims={1}, 741 // start_index_map={1}, 742 // index_vector_dim=1, 743 // slice_sizes={3,1,2} 744 // reshape = s32[6,7] reshape(gather) 745 // 746 // In this case the gather maps to: 747 // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2]) 748 // 749 // and the reshape passes through dimension 2 from its input into dimension 1 750 // in its output. However, we can't rewrite the reshape as a scalar-indexed 751 // node because then we'd have to reshape the [3,5,2] `operand` array to 752 // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently 753 // (a.k.a. isn't pass-through) than the [3,5,2] array. 754 755 if (source_dim_for_new_scalar_indexed_node == -1) { 756 VLOG(3) << "Could not compute the source dim for the new scalar indexed " 757 "node: scalar_indexed_source_shape = [" 758 << StrJoin(scalar_indexed_source_shape.dimensions(), ",") 759 << "] and new_scalar_indexed_source_shape = [" 760 << StrJoin(new_scalar_indexed_source_shape, ",") << "]"; 761 return nullptr; 762 } 763 764 InsertAt( 765 &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, 766 scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); 767 768 CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL, 769 std::multiplies<int64>()), 770 ShapeUtil::ElementsIn(scalar_indexed_source_shape)); 771 772 CHECK(IsReshapePassthroughOperandDim( 773 ComputeReshapePassthroughDimPairs( 774 /*operand_shape=*/AsInt64Slice( 775 scalar_indexed_source_shape.dimensions()), 776 /*result_shape=*/new_scalar_indexed_source_shape), 777 scalar_indexed->source_dim())); 778 779 auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) { 780 return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims, 781 result_dim); 782 }; 783 784 std::vector<int64> output_dims_for_new_scalar_indexed_node; 785 absl::c_transform(scalar_indexed->output_dims(), 786 std::back_inserter(output_dims_for_new_scalar_indexed_node), 787 map_passthrough_operand_dim_to_result_dim); 788 789 TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, 790 TakeOwnership(scalar_indexed->literal().Reshape( 791 new_scalar_indexed_source_shape))); 792 TF_ASSIGN_OR_RETURN( 793 Array * new_scalar_indexed_source, 794 ComputeArrayForConstant(*new_scalar_indexed_source_literal)); 795 796 return ConstructScalarIndexedArray( 797 new_scalar_indexed_source, scalar_indexed->indices(), 798 source_dim_for_new_scalar_indexed_node, 799 output_dims_for_new_scalar_indexed_node, shape); 800 } 801 802 StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape( 803 const Shape& shape, Array* operand) { 804 if (ShapeUtil::Compatible(operand->shape(), shape)) { 805 return operand; 806 } 807 808 if (auto* scalar_indexed = 809 dynamic_cast<ScalarIndexedConstantArray*>(operand)) { 810 TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather, 811 FoldReshapeOfGather(shape, scalar_indexed)); 812 if (reshape_folded_into_gather) { 813 return reshape_folded_into_gather; 814 } 815 } 816 817 if (auto* constant_array = dynamic_cast<ConstantArray*>(operand)) { 818 TF_ASSIGN_OR_RETURN(Literal* const new_literal, 819 TakeOwnership(constant_array->literal()->Reshape( 820 AsInt64Slice(shape.dimensions())))); 821 return Construct<ConstantArray>(new_literal); 822 } 823 824 return Construct<ReshapedArray>(operand, shape); 825 } 826 827 StatusOr<Analysis::Array*> 828 IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, 829 Array* lhs, 830 Array* rhs) { 831 // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices)) 832 // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices) 833 // 834 // We can do this if every output dimension from the scalar-indexed node is a 835 // broadcasted dimension for the broadcast node. Informally, the precondition 836 // means Broadcast(Const0)[IDX] is solely a function of the components of IDX 837 // that are not output-dims for the scalar-indexed node. In other words, for 838 // every assignment to the non-output dims in IDX we have a "constant" LHS to 839 // the BinaryOp. This transform propagates this "constant" to the source for 840 // the scalar-indexed node. 841 842 ScalarIndexedConstantArray* lhs_scalar_indexed_const = 843 dynamic_cast<ScalarIndexedConstantArray*>(lhs); 844 ScalarIndexedConstantArray* rhs_scalar_indexed_const = 845 dynamic_cast<ScalarIndexedConstantArray*>(rhs); 846 847 bool lhs_is_indexed; 848 849 // One of the operands must be scalar-indexed and the other must be a 850 // broadcast of a constant. 851 if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) { 852 lhs_is_indexed = true; 853 } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) { 854 lhs_is_indexed = false; 855 } else { 856 return nullptr; 857 } 858 859 ScalarIndexedConstantArray* scalar_indexed_const = 860 lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const; 861 UnknownArray* candidate_broadcast_array = 862 dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs); 863 if (!candidate_broadcast_array || 864 candidate_broadcast_array->instruction().opcode() != 865 HloOpcode::kBroadcast) { 866 return nullptr; 867 } 868 869 const HloInstruction* broadcast_instr = 870 &candidate_broadcast_array->instruction(); 871 const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0); 872 if (broadcast_const_operand->opcode() != HloOpcode::kConstant) { 873 return nullptr; 874 } 875 876 absl::Span<const int64> broadcast_dims = broadcast_instr->dimensions(); 877 auto is_broadcasted_dim = [&](int64 output_dim) { 878 return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); 879 }; 880 881 // All of the output dims must be "broadcasted" dims for the other operand. 882 if (!absl::c_all_of(scalar_indexed_const->output_dims(), 883 is_broadcasted_dim)) { 884 return nullptr; 885 } 886 887 // To figure out the broadcast dimensions for the (constant) source for the 888 // scalar-indexed node, we "simulate" the index transformation done by the 889 // existing broadcsat: 890 enum class IndexComponent { Broadcasted, NotBroadcasted }; 891 std::vector<IndexComponent> simulated_index( 892 broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted); 893 for (int64 broadcast_dim : broadcast_dims) { 894 simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted; 895 } 896 897 // The scalar-indexed node "removes" the source dim and "inserts" the output 898 // dims. We do the opposite here to undo the scalar-indexed operation. 899 absl::Span<const int64> output_dims = scalar_indexed_const->output_dims(); 900 for (int64 i = output_dims.size() - 1; i >= 0; --i) { 901 CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); 902 EraseAt(&simulated_index, output_dims[i]); 903 } 904 905 InsertAt(&simulated_index, scalar_indexed_const->source_dim(), 906 IndexComponent::Broadcasted); 907 908 // new_inner_broadcast_dims holds the broadcast dimensions for the inner 909 // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to 910 // new_inner_broadcast_dims. 911 std::vector<int64> new_inner_broadcast_dims; 912 for (int64 i = 0; i < simulated_index.size(); i++) { 913 if (simulated_index[i] == IndexComponent::NotBroadcasted) { 914 new_inner_broadcast_dims.push_back(i); 915 } 916 } 917 918 // inner_broadcast_result is the Broadcast'(Const0) bit in 919 // BinaryOp(Broadcast'(Const0), Const1) 920 TF_ASSIGN_OR_RETURN( 921 Literal inner_broadcast_result, 922 broadcast_const_operand->literal().Broadcast( 923 scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); 924 925 // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1) 926 const Literal* literal_for_new_source; 927 if (lhs_is_indexed) { 928 TF_ASSIGN_OR_RETURN( 929 literal_for_new_source, 930 TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( 931 opcode, scalar_indexed_const->literal(), inner_broadcast_result))); 932 } else { 933 TF_ASSIGN_OR_RETURN( 934 literal_for_new_source, 935 TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( 936 opcode, inner_broadcast_result, scalar_indexed_const->literal()))); 937 } 938 939 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source); 940 return Construct<ScalarIndexedConstantArray>( 941 new_source, scalar_indexed_const->indices(), 942 scalar_indexed_const->source_dim(), 943 std::vector<int64>(scalar_indexed_const->output_dims().begin(), 944 scalar_indexed_const->output_dims().end()), 945 scalar_indexed_const->shape()); 946 } 947 948 StatusOr<Analysis::Array*> 949 IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, 950 Array* operand) { 951 auto* scalar_indexed_const = 952 dynamic_cast<ScalarIndexedConstantArray*>(operand); 953 if (scalar_indexed_const == nullptr) { 954 return nullptr; 955 } 956 957 // Fold UnaryOp(ScalarIndexed(Const, Indices)) 958 // => ScalarIndexed(UnaryOp(Const), Indices) 959 960 TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, 961 TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp( 962 opcode, scalar_indexed_const->literal()))); 963 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source); 964 return Construct<ScalarIndexedConstantArray>( 965 new_source, scalar_indexed_const->indices(), 966 scalar_indexed_const->source_dim(), 967 ArraySliceToVector(scalar_indexed_const->output_dims()), 968 scalar_indexed_const->shape()); 969 } 970 971 namespace { 972 973 // Returns the non-contracting non-batch dimension (as per `contracting_dims` 974 // and `batch_dims`) if there is exactly one, otherwise returns nullopt. 975 absl::optional<int64> GetOnlyNonContractingNonBatchDim( 976 int64 rank, absl::Span<const int64> contracting_dims, 977 absl::Span<const int64> batch_dims) { 978 absl::optional<int64> result; 979 for (int64 dim = 0; dim < rank; dim++) { 980 if (!absl::c_linear_search(contracting_dims, dim) && 981 !absl::c_linear_search(batch_dims, dim)) { 982 if (result.has_value()) { 983 return absl::nullopt; 984 } 985 result = dim; 986 } 987 } 988 return result; 989 } 990 991 // Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot 992 // HLO, can be folded into the dot operation. For now these conditions are both 993 // necessary and sufficient. 994 // 995 // `tag` describes the caller. Used only for logging. 996 // 997 // `contracting_dims` and `batch_dims` are the contracting and batch dimensions 998 // of whatever operand `indexed_array` is to the dot (LHS or RHS). 999 bool CanFoldDotIntoIndexedArray( 1000 absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, 1001 absl::Span<const int64> contracting_dims, 1002 absl::Span<const int64> batch_dims) { 1003 absl::optional<int64> non_contracting_non_batch_dim = 1004 GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(), 1005 contracting_dims, batch_dims); 1006 if (!non_contracting_non_batch_dim.has_value()) { 1007 VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions"; 1008 return false; 1009 } 1010 1011 if (indexed_array->output_dims().size() != 1 || 1012 indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) { 1013 VLOG(3) << tag << ": output dims != the lhs non-contracting non-batch dim"; 1014 return false; 1015 } 1016 1017 int64 indexed_array_rank = indexed_array->shape().rank(); 1018 if (indexed_array->source_dim() < (indexed_array_rank - 2)) { 1019 // This restriction can be lifted by inserting reshape nodes. 1020 VLOG(3) << tag 1021 << ": source dim is not in the low two dims, won't be able to form " 1022 "a matmul"; 1023 return false; 1024 } 1025 1026 return true; 1027 } 1028 1029 } // namespace 1030 1031 StatusOr<Analysis::Array*> 1032 IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( 1033 const Shape& shape, const DotDimensionNumbers& dim_numbers, 1034 const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, 1035 ConstantArray* rhs) { 1036 VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " 1037 << ToString(rhs); 1038 if (!CanFoldDotIntoIndexedArray( 1039 "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/ 1040 AsInt64Slice(dim_numbers.lhs_contracting_dimensions()), 1041 /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) { 1042 return nullptr; 1043 } 1044 1045 int64 lhs_rank = lhs->shape().rank(); 1046 DotDimensionNumbers new_dim_numbers = dim_numbers; 1047 new_dim_numbers.set_lhs_contracting_dimensions( 1048 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); 1049 1050 TF_ASSIGN_OR_RETURN( 1051 Literal * literal_for_new_source, 1052 TakeOwnership(HloEvaluator{}.EvaluateDotOp( 1053 new_dim_numbers, precision_config, lhs->literal(), *rhs->literal()))); 1054 1055 // The new source dimension is wherever the non-batch non-contracting LHS 1056 // dimension "went". 1057 int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() + 1058 dim_numbers.rhs_batch_dimensions_size(); 1059 1060 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source); 1061 return Construct<ScalarIndexedConstantArray>( 1062 new_source, lhs->indices(), new_source_dim, 1063 ArraySliceToVector(lhs->output_dims()), shape); 1064 } 1065 1066 StatusOr<Analysis::Array*> 1067 IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( 1068 const Shape& shape, const DotDimensionNumbers& dim_numbers, 1069 const PrecisionConfig& precision_config, ConstantArray* lhs, 1070 ScalarIndexedConstantArray* rhs) { 1071 VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " 1072 << ToString(rhs); 1073 if (!CanFoldDotIntoIndexedArray( 1074 "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/ 1075 AsInt64Slice(dim_numbers.rhs_contracting_dimensions()), 1076 /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) { 1077 return nullptr; 1078 } 1079 1080 int64 rhs_rank = rhs->shape().rank(); 1081 1082 DotDimensionNumbers new_dim_numbers = dim_numbers; 1083 new_dim_numbers.set_rhs_contracting_dimensions( 1084 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); 1085 1086 TF_ASSIGN_OR_RETURN( 1087 Literal * literal_for_new_source, 1088 TakeOwnership(HloEvaluator{}.EvaluateDotOp( 1089 new_dim_numbers, precision_config, *lhs->literal(), rhs->literal()))); 1090 1091 // The new source dimension is wherever the non-batch non-contracting RHS 1092 // dimension "went". 1093 int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() + 1094 dim_numbers.rhs_batch_dimensions_size() + 1; 1095 1096 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source); 1097 return Construct<ScalarIndexedConstantArray>( 1098 new_source, rhs->indices(), new_source_dim, 1099 ArraySliceToVector(rhs->output_dims()), shape); 1100 } 1101 1102 StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot( 1103 const Shape& shape, const DotDimensionNumbers& dim_numbers, 1104 const PrecisionConfig& precision_config, Array* lhs, Array* rhs) { 1105 // Intuitively, if 1106 // 1107 // - The LHS of a dot product is a gathered sequence of rows from a constant 1108 // array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant 1109 // 1110 // OR 1111 // 1112 // - If the RHS of a dot product is a gathered sequence of columns from a 1113 // constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a 1114 // constant 1115 // 1116 // then the result of the dot product itself is a gather from a constant 1117 // array. E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be 1118 // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I], 1119 // J]. 1120 // 1121 // We do a general version of this rewrite here. 1122 VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs); 1123 if (auto* lhs_indexed_array = 1124 dynamic_cast<ScalarIndexedConstantArray*>(lhs)) { 1125 if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) { 1126 return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, 1127 precision_config, 1128 lhs_indexed_array, rhs_constant); 1129 } 1130 } 1131 1132 if (auto* rhs_indexed_array = 1133 dynamic_cast<ScalarIndexedConstantArray*>(rhs)) { 1134 if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) { 1135 return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, 1136 precision_config, lhs_constant, 1137 rhs_indexed_array); 1138 } 1139 } 1140 1141 return nullptr; 1142 } 1143 1144 absl::string_view IndexedArrayAnalysisPrinterPass::name() const { 1145 return "indexed-array-analysis-printer-pass"; 1146 } 1147 1148 StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { 1149 if (!VLOG_IS_ON(2)) { 1150 return false; 1151 } 1152 1153 IndexedArrayAnalysis analysis; 1154 for (auto* computation : module->MakeNonfusionComputations()) { 1155 for (auto* instr : computation->instructions()) { 1156 TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr)); 1157 if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) { 1158 VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); 1159 } 1160 } 1161 } 1162 1163 return false; 1164 } 1165 1166 } // namespace xla 1167