/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace {
using Analysis = IndexedArrayAnalysis;
using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
using absl::StrJoin;
}  // namespace

string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
  switch (root->kind()) {
    case Array::kUnknown: {
      auto* unknown_tensor = root->as<UnknownArray>();
      return absl::StrCat("%", unknown_tensor->instruction().name());
    }

    case Array::kConstant: {
      if (print_constants) {
        string contents = root->as<ConstantArray>()->literal()->ToString();
        return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                            " ", contents, ")");
      }
      return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                          ")");
    }

    case Array::kReshaped: {
      ReshapedArray* reshaped_array = root->as<ReshapedArray>();
      return absl::StrCat(
          "(reshape ", ToString(reshaped_array->operand(), print_constants),
          " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
    }

    case Array::kScalarIndexedConstant:
    case Array::kScalarIndexed: {
      auto* indexed_array = root->as<ScalarIndexedArray>();
      string name = root->kind() == Array::kScalarIndexedConstant
                        ? "scalar-indexed-const"
                        : "scalar-indexed";
      return absl::StrCat(
          "(", name, " ", ToString(indexed_array->source(), print_constants),
          " ", ToString(indexed_array->indices(), print_constants), " ",
          indexed_array->source_dim(), "->[",
          StrJoin(indexed_array->output_dims(), ","), "])");
    }
  }
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
    const HloInstruction* instr) {
  auto it = cache_.find(instr);
  if (it != cache_.end()) {
    return it->second;
  }

  TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
  return FindOrDie(cache_, instr);
}

Status IndexedArrayAnalysis::TraverseAndPopulateCache(
    const HloInstruction* root) {
  // Depth first search over the DAG, invoking ComputeArrayFor in post order.
  // The HLO instructions already in the cache are considered leaves.

  absl::InlinedVector<const HloInstruction*, 4> stack;

  enum DfsState { kDiscovered, kVisited };
  absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;

  stack.push_back(root);
  InsertOrDie(&dfs_state_map, root, kDiscovered);

  do {
    const HloInstruction* instr = stack.back();
    if (cache_.contains(instr)) {
      stack.pop_back();
      continue;
    }

    switch (FindOrDie(dfs_state_map, instr)) {
      case kDiscovered: {
        for (const HloInstruction* operand : instr->operands()) {
          if (!cache_.contains(operand)) {
            stack.push_back(operand);
            CHECK(!dfs_state_map.contains(operand) ||
                  dfs_state_map[operand] == kDiscovered);
            dfs_state_map[operand] = kDiscovered;
          }
        }
        dfs_state_map[instr] = kVisited;
        break;
      }

      case kVisited:
        stack.pop_back();
        TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
        InsertOrDie(&cache_, instr, array);
        break;
    }
  } while (!stack.empty());

  return Status::OK();
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
    const HloInstruction* instr) {
  Array* computed_array;
  if (instr->IsElementwise() && instr->operand_count() == 1) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseUnaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0))));
  } else if (instr->IsElementwise() && instr->operand_count() == 2) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseBinaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0)),
            FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kConstant) {
    TF_ASSIGN_OR_RETURN(computed_array,
                        ComputeArrayForConstant(instr->literal()));
  } else if (instr->opcode() == HloOpcode::kGather) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
                              instr->gather_slice_sizes(),
                              FindOrDie(cache_, instr->operand(0)),
                              FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kReshape) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForReshape(instr->shape(),
                               FindOrDie(cache_, instr->operand(0))));
  } else if (instr->opcode() == HloOpcode::kDot) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
                           instr->precision_config(),
                           FindOrDie(cache_, instr->operand(0)),
                           FindOrDie(cache_, instr->operand(1))));
  } else {
    computed_array = nullptr;
  }

  if (!computed_array) {
    computed_array = Construct<UnknownArray>(instr);
  }

  return computed_array;
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
    const Literal& literal) {
  return Construct<ConstantArray>(&literal);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
    ScalarIndexedArray* source, Array* indices, int64 source_dim,
    absl::Span<const int64> output_dims, Shape shape) {
  // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
  // `source` is the inner Gather(A, X).

  Array* a = source->source();
  Array* x = source->indices();
  Array* y = indices;

  // This bit is slightly tricky, so we do a naive "simulation" of the two
  // consecutive gather operations to infer what the composed gather should
  // look like.
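  //
  // For intuition, a small worked example (with illustrative shapes, not
  // drawn from any particular module): take A = s32[3,5], X = s32[2] and
  // Y = s32[4], with the inner gather
  //   G1 = (scalar-indexed A X 0->[0]),  so G1[i,j] = A[X[i],j],
  // and the outer gather
  //   G2 = (scalar-indexed G1 Y 0->[0]), so G2[k,j] = G1[Y[k],j] = A[X[Y[k]],j].
  // The simulation below derives exactly this composition:
  //   (scalar-indexed A (scalar-indexed X Y 0->[0]) 0->[0]).
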
  enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond };

  std::vector<IndexComponent> simulated_index(a->shape().dimensions_size(),
                                              IndexComponent::Ungathered);

  // Simulate the first gather.
  EraseAt(&simulated_index, source->source_dim());
  for (int64 gather_dim : source->output_dims()) {
    simulated_index.insert(simulated_index.begin() + gather_dim,
                           IndexComponent::GatheredFirst);
  }

  // Simulate the second gather.
  EraseAt(&simulated_index, source_dim);
  for (int64 output_dim : output_dims) {
    simulated_index.insert(simulated_index.begin() + output_dim,
                           IndexComponent::GatheredSecond);
  }

  int64 source_dim_for_index_array =
      FindIndex(source->output_dims(), source_dim);
  CHECK_NE(source_dim_for_index_array, source->output_dims().size());

  std::vector<int64> output_dims_for_index_array;
  int64 gathered_index_components_seen = 0;
  for (IndexComponent simulation_dim : simulated_index) {
    if (simulation_dim == IndexComponent::GatheredSecond) {
      output_dims_for_index_array.push_back(gathered_index_components_seen);
    }
    if (simulation_dim != IndexComponent::Ungathered) {
      gathered_index_components_seen++;
    }
  }

  std::vector<int64> dim_sizes_for_composed_index;
  std::vector<int64> output_dims_for_new_gather;
  for (int64 i = 0, e = simulated_index.size(); i < e; i++) {
    if (simulated_index[i] != IndexComponent::Ungathered) {
      dim_sizes_for_composed_index.push_back(shape.dimensions(i));
      output_dims_for_new_gather.push_back(i);
    }
  }

  Array* inner_indices = ConstructScalarIndexedArray(
      x, y, source_dim_for_index_array, output_dims_for_index_array,
      ShapeUtil::MakeShape(x->shape().element_type(),
                           dim_sizes_for_composed_index));
  return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(),
                                     output_dims_for_new_gather,
                                     std::move(shape));
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
    const Shape& shape, const GatherDimensionNumbers& dim_numbers,
    absl::Span<const int64> slice_sizes, Array* source, Array* indices) {
  if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
    VLOG(3) << "ComputeArrayForGather: indices are not scalar";
    return nullptr;
  }

  CHECK_EQ(dim_numbers.start_index_map_size(), 1);

  // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
  // should it become relevant.

  if (dim_numbers.collapsed_slice_dims_size() != 1 ||
      dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
    VLOG(3) << "ComputeArrayForGather: gather operations must elide "
               "start_index_map[0] and "
               "start_index_map[0] only";
    return nullptr;
  }

  // ScalarIndexedArray cannot represent gathers that "slice" along some
  // dimensions -- for instance it cannot represent a gather that picks 5 [2,3]
  // arrays from an array of size [7,4,6].  We check that condition down below:

  for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
    if (i != dim_numbers.collapsed_slice_dims(0) &&
        source->shape().dimensions(i) != slice_sizes[i]) {
      VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
              << "] != source->shape().dimensions(" << i << ") -- "
              << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
              << " with dim_numbers.collapsed_slice_dims(0) = "
              << dim_numbers.collapsed_slice_dims(0);
      return nullptr;
    }
  }

  int64 source_dim = dim_numbers.start_index_map(0);
  std::vector<int64> output_dims;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
      output_dims.push_back(i);
    }
  }

  if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
    if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
      return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
                                shape);
    }
  } else if (auto* constant = dynamic_cast<ConstantArray*>(source)) {
    return Construct<ScalarIndexedConstantArray>(constant, indices, source_dim,
                                                 output_dims, shape);
  }

  return Construct<ScalarIndexedArray>(source, indices, source_dim, output_dims,
                                       shape);
}

namespace {
// Returns an index into `values` such that the product of the range
// [values.begin()+index, values.end()) is equal to `product`.  If there is no
// such index, returns -1.  All integers in `values` must be positive.
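// For example (illustrative): FindSuffixWithProduct({2, 3, 4}, 12) returns 1
// (since 3 * 4 == 12), and FindSuffixWithProduct({2, 3, 4}, 6) returns -1
// because no suffix of {2, 3, 4} has product 6.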
int64 FindSuffixWithProduct(absl::Span<const int64> values, int64 product) {
  DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));

  int64 current_product = 1;
  int64 i;
  for (i = values.size() - 1; i >= 0 && product > current_product; --i) {
    current_product *= values[i];
  }

  if (product == current_product) {
    return i + 1;
  }

  return -1;
}

struct ReshapePassthroughDimPair {
  int64 result_dim;
  int64 operand_dim;
};

// Returns a set of dimension pairs such that for all (result_dim, operand_dim)
// in the set:
//
// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim]
//
// The returned vector of pairs is sorted in both the result_dim and the
// operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
    absl::Span<const int64> operand_shape,
    absl::Span<const int64> result_shape) {
  // A reshape can be seen as an index mapping from output index to input index:
  //
  // (i_0, ..., i_n) = f(o_0, ..., o_m)
  //
  // This function returns the pairs (j, k) for which the following invariant
  // holds for all indices in the shape:
  //
  //   o_j == i_k
  //
  // And this occurs when:
  //
  //    O_{j+1} * ... * O_n == I_{k+1} * ... * I_m
  //
  // (where O_x are the sizes of the output shape and I_x are the sizes of the
  // input shape) and the size of dimension j of the result is the same as the
  // size of dimension k in the operand.
  //
  // These conditions are sufficient because the Reshape HLO is spec'ed such
  // that the rightmost dimensions are always minor in the flatten and refine
  // operations.
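  //
  // For example (illustrative): in a [4,5] -> [2,2,5] reshape, only the
  // trailing dimension passes through -- result dim 2 pairs with operand dim 1
  // (both have size 5 and equal suffix products), while operand dim 0 (size 4)
  // is split into 2x2 and pairs with nothing, so the result is {{2, 1}}.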

  std::vector<ReshapePassthroughDimPair> result;
  int64 result_subarray_size = 1;
  for (int64 result_dim = result_shape.size() - 1; result_dim >= 0;
       --result_dim) {
    int64 candidate_operand_dim =
        FindSuffixWithProduct(operand_shape, result_subarray_size);

    // result_subarray_size does not include the elements in the current
    // `result_dim` dimension (we multiply in result_shape[result_dim] at the
    // end of the loop body) so candidate_operand_dim can never be zero.
    CHECK_NE(candidate_operand_dim, 0)
        << "result_dim = " << result_dim
        << ", result_subarray_size = " << result_subarray_size
        << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
        << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";

    if (candidate_operand_dim != -1 &&
        result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
      result.push_back({/*result_dim=*/result_dim,
                        /*operand_dim=*/candidate_operand_dim - 1});
    }
    result_subarray_size *= result_shape[result_dim];
  }

  absl::c_reverse(result);

  if (VLOG_IS_ON(3)) {
    std::vector<string> result_strings;
    absl::c_transform(result, std::back_inserter(result_strings),
                      [](ReshapePassthroughDimPair value) {
                        return absl::StrCat(value.result_dim, "->",
                                            value.operand_dim);
                      });
    VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
            << StrJoin(result_shape, ",") << "] passthrough indices are ["
            << StrJoin(result_strings, ",")
            << "] (legend: `result`->`operand`)";
  }

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.result_dim < rhs.result_dim;
      }));

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.operand_dim < rhs.operand_dim;
      }));

  return result;
}

// Returns true if `dim` occurs as a passthrough operand dim in
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
  return absl::c_any_of(passthrough_dims,
                        [&](ReshapePassthroughDimPair passthrough_dim_pair) {
                          return passthrough_dim_pair.operand_dim == dim;
                        });
}

// Maps `operand_dim`, which must be a passthrough operand dimension, to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
    int64 operand_dim) {
  auto it = absl::c_find_if(
      passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
        return passthrough_dim_pair.operand_dim == operand_dim;
      });
  CHECK(it != passthrough_dims.end());
  return it->result_dim;
}
int64 FindSourcePositionForPassthroughResultDim(
    absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape,
    int64 source_passthrough_dim) {
  VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
          << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
          << "], " << source_passthrough_dim << ")";

  int64 indexed_source_subarray_size =
      std::accumulate(operand_shape.begin() + source_passthrough_dim + 1,
                      operand_shape.end(), 1LL, std::multiplies<int64>());

  return FindSuffixWithProduct(result_shape, indexed_source_subarray_size);
}

Shape StripDegenerateDimensions(const Shape& shape) {
  DimensionVector new_dims;
  absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
                  [](int64 dim) { return dim != 1; });
  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}
}  // namespace

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
    ScalarIndexedArray* operand) {
  const Shape& shape = operand->shape();
  if (!ShapeUtil::HasDegenerateDimensions(shape)) {
    return operand;
  }

  // We only need to reshape out the degenerate dims from the indices and the
  // source (except the source dim).

  const Shape& source_shape = operand->source()->shape();
  DimensionVector new_source_shape_dims;
  for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) {
    if (i == operand->source_dim() || source_shape.dimensions(i) != 1) {
      new_source_shape_dims.push_back(source_shape.dimensions(i));
    }
  }

  Shape new_source_shape =
      ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims);
  Shape new_indices_shape =
      StripDegenerateDimensions(operand->indices()->shape());

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  TF_ASSIGN_OR_RETURN(
      Array* const new_indices,
      ComputeArrayForReshape(new_indices_shape, operand->indices()));

  // Build the new output dims while keeping track of the degenerate dims that
  // will no longer be present.
  DimensionVector new_output_dims;
  int64 degenerate_dims_seen = 0;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_dims_seen++;
    } else if (absl::c_linear_search(operand->output_dims(), i)) {
      new_output_dims.push_back(i - degenerate_dims_seen);
    }
  }

  // Similarly, build the new source dim while keeping track of the degenerate
  // dims that will no longer be present.
  int64 degenerate_dims_before_source_dim =
      std::count(source_shape.dimensions().begin(),
                 source_shape.dimensions().begin() + operand->source_dim(), 1);
  int64 new_source_dim =
      operand->source_dim() - degenerate_dims_before_source_dim;

  return ConstructScalarIndexedArray(
      new_source, new_indices, new_source_dim,
      InlinedVectorToVector(new_output_dims),
      StripDegenerateDimensions(operand->shape()));
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
    ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims) {
  if (degenerate_dims.empty()) {
    return operand;
  }

  CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape()));

  DimensionVector new_output_dims = [&]() {
    // To make things easy we use a "scratch" buffer of bools where the i'th
    // element is true iff the i'th component of the result index is an output
    // index.

    absl::InlinedVector<bool, 6> output_dims_bitvector(
        operand->shape().dimensions_size());
    for (int64 output_dim : operand->output_dims()) {
      output_dims_bitvector[output_dim] = true;
    }

    for (int64 degenerate_dim : degenerate_dims) {
      InsertAt(&output_dims_bitvector, degenerate_dim, false);
    }

    DimensionVector result;
    result.reserve(operand->output_dims().size());
    for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) {
      if (output_dims_bitvector[i]) {
        result.push_back(i);
      }
    }

    return result;
  }();

  DimensionVector new_result_shape_dims;
  absl::c_copy(operand->shape().dimensions(),
               std::back_inserter(new_result_shape_dims));
  for (int64 degenerate_dim : degenerate_dims) {
    InsertAt(&new_result_shape_dims, degenerate_dim, 1);
  }

  DimensionVector new_source_shape_dims = new_result_shape_dims;
  // Erase from the highest output dim down so that earlier erasures do not
  // shift the positions of later ones.
  for (int64 i = new_output_dims.size() - 1; i >= 0; --i) {
    EraseAt(&new_source_shape_dims, new_output_dims[i]);
  }

  int64 new_source_dim = [&]() {
    // Walk the new shape, counting non-degenerate dims as we go; the source
    // dim goes at the position preceded by exactly operand->source_dim()
    // non-degenerate dims.
    int64 non_degenerate_dims_seen = 0;
    for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) {
      if (non_degenerate_dims_seen == operand->source_dim()) {
        return i;
      }
      if (new_source_shape_dims[i] != 1) {
        non_degenerate_dims_seen++;
      }
    }
    LOG(FATAL) << "Did not find source dim in " << ToString(operand);
  }();

  int64 source_dim_size =
      operand->source()->shape().dimensions(operand->source_dim());
  InsertAt(&new_source_shape_dims, /*index=*/new_source_dim,
           /*value=*/source_dim_size);

  Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_source_shape_dims);
  Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_result_shape_dims);

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  return ConstructScalarIndexedArray(
      new_source, operand->indices(), new_source_dim,
      InlinedVectorToVector(new_output_dims), new_result_shape);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldReshapeOfGather(
    const Shape& shape, ScalarIndexedConstantArray* operand) {
  VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")";

  // To make things easier on ourselves, instead of directly trying to fold the
  // reshape of `operand` to `shape`, we call
  // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and
  // handle the degenerate dimensions here by inserting reshapes.

  TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims,
                      ReshapeToRemoveDegenerateDims(operand));

  Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape);
  TF_ASSIGN_OR_RETURN(
      ScalarIndexedArray* const folded_reshape_without_degenerate_dims,
      FoldReshapeOfGatherNoDegenerateDims(
          output_shape_without_degenerate_dims,
          operand_without_degenerate_dims->as<ScalarIndexedConstantArray>()));

  if (folded_reshape_without_degenerate_dims == nullptr) {
    return nullptr;
  }

  DimensionVector degenerate_result_dims;
  for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_result_dims.push_back(i);
    }
  }

  return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims,
                                    degenerate_result_dims);
}

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
    const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) {
  VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed)
          << ")";
  CHECK(!ShapeUtil::HasDegenerateDimensions(shape));
  CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape()));

  // Try to fold Reshape(ScalarIndexed(Const, Indices))
  //          => ScalarIndexed(Const', Indices)
  //
  // We can view the reshape and the scalar-indexed operations as functions that
  // map an output index (i.e. an index into the result) to an input index
  // (i.e. an index into the operand).  The key idea used here is that the
  // output-to-input mapping for some reshape operations may "pass through" some
  // output dimensions into the input space unchanged -- i.e. there may exist
  // output dimension "O" and input dimension "I" such that OutputIndex[O] is
  // always == InputIndexForReshape(OutputIndex)[I].  If these pass-through
  // dimensions in the input space of the reshape happen to include all the
  // output dimensions for the scalar-indexed node then, roughly, the following
  // holds:
  //
  //    SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx))
  // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs))
  //
  //      Where Ps are the set of the pass-through components of Idx that are
  //      also the output dims of the scalar-indexed node, and Qs are the rest.
  //      For brevity, we're playing fast and loose with the notation here -- we
  //      don't literally require Idx to be a concatenation of Ps and Qs, as
  //      suggested by the "++".
  //
  // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs))
  //
  //      Again, we're playing fast and loose with the notation around "++".
  //      Generally this ++ will be a different function than the ++ in the
  //      previous step.
  //
  // If the scalar-indexed node has a constant as the source then the
  // SourceIndexOfReshape function can be "folded into" the constant itself by
  // reshaping it, leaving us with:
  //
  // == SourceIndexOfScalarIndexed(Ps ++ Qs)
  // == SourceIndexOfScalarIndexed(Idx)
  //
  // which is just a scalar-indexed node (with parameters different from the
  // scalar-indexed node we started with) with a reshaped constant as the
  // source.
  //
  // We can't fold SourceIndexOfReshape into the constant without introducing
  // another precondition: since the new scalar-indexed node will have a
  // reshaped (constant) array as its source it will, in general, have a
  // different source dimension than the original scalar-indexed node.  This
  // source dimension will have to be a passthrough dimension of the
  // SourceIndexOfReshape indexing function that is folded into the source, and
  // such a dimension need not exist, so this is a non-trivial precondition.

  std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims =
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()),
          /*result_shape=*/AsInt64Slice(shape.dimensions()));

  auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) {
    return IsReshapePassthroughOperandDim(reshape_passthrough_dims,
                                          operand_dim);
  };

  if (!absl::c_all_of(scalar_indexed->output_dims(),
                      is_reshape_passthrough_operand_dim)) {
    VLOG(3) << "Not all output dims are passthrough dims "
            << ToString(scalar_indexed);
    return nullptr;
  }

  // To compute the shape of the source for the new scalar-indexed node we're
  // going to create, we first "undo" the scalar-indexed operation.
  std::vector<int64> new_scalar_indexed_source_shape(shape.dimensions().begin(),
                                                     shape.dimensions().end());
  for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) {
    int64 output_dim = scalar_indexed->output_dims()[i];
    int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
        reshape_passthrough_dims, output_dim);
    EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
  }

  // After this, we need to add in the dimension that will be the source
  // dimension for the new scalar-indexed node.  A scalar-indexed node "removes"
  // the source dimensions and "adds" the output dimensions, so to get back to
  // the shape for the *source* of the scalar-indexed node we need to remove the
  // output dims (which we did above) and then add back the source dim (which we
  // are about to do below):

  const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape();

  int64 source_dim_for_new_scalar_indexed_node =
      FindSourcePositionForPassthroughResultDim(
          /*operand_shape=*/AsInt64Slice(
              scalar_indexed_source_shape.dimensions()),
          /*result_shape=*/new_scalar_indexed_source_shape,
          scalar_indexed->source_dim());

  // We may not be able to find a source dim for the new scalar-indexed node.
  // For instance consider:
  //
  //   operand = s32[3,5,2] constant({...})
  //   indices = s32[7] parameter(0)
  //   gather = s32[3,2,7] gather(operand, indices),
  //       offset_dims={0,1},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3,1,2}
  //   reshape = s32[6,7] reshape(gather)
  //
  // In this case the gather maps to:
  //    (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])
  //
  // and the reshape passes through dimension 2 from its input into dimension 1
  // in its output.  However, we can't rewrite the reshape as a scalar-indexed
  // node because then we'd have to reshape the [3,5,2] `operand` array to
  // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently
  // (a.k.a. isn't pass-through) than the [3,5,2] array.

  if (source_dim_for_new_scalar_indexed_node == -1) {
    VLOG(3) << "Could not compute the source dim for the new scalar indexed "
               "node: scalar_indexed_source_shape = ["
            << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
            << "] and new_scalar_indexed_source_shape = ["
            << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
    return nullptr;
  }

  InsertAt(
      &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
      scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));

  CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
                              std::multiplies<int64>()),
           ShapeUtil::ElementsIn(scalar_indexed_source_shape));

  CHECK(IsReshapePassthroughOperandDim(
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/AsInt64Slice(
              scalar_indexed_source_shape.dimensions()),
          /*result_shape=*/new_scalar_indexed_source_shape),
      scalar_indexed->source_dim()));

  auto map_passthrough_operand_dim_to_result_dim = [&](int64 operand_dim) {
    return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims,
                                               operand_dim);
  };

  std::vector<int64> output_dims_for_new_scalar_indexed_node;
  absl::c_transform(scalar_indexed->output_dims(),
                    std::back_inserter(output_dims_for_new_scalar_indexed_node),
                    map_passthrough_operand_dim_to_result_dim);

  TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
                      TakeOwnership(scalar_indexed->literal().Reshape(
                          new_scalar_indexed_source_shape)));
  TF_ASSIGN_OR_RETURN(
      Array * new_scalar_indexed_source,
      ComputeArrayForConstant(*new_scalar_indexed_source_literal));

  return ConstructScalarIndexedArray(
      new_scalar_indexed_source, scalar_indexed->indices(),
      source_dim_for_new_scalar_indexed_node,
      output_dims_for_new_scalar_indexed_node, shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
    const Shape& shape, Array* operand) {
  if (ShapeUtil::Compatible(operand->shape(), shape)) {
    return operand;
  }

  if (auto* scalar_indexed =
          dynamic_cast<ScalarIndexedConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather,
                        FoldReshapeOfGather(shape, scalar_indexed));
    if (reshape_folded_into_gather) {
      return reshape_folded_into_gather;
    }
  }

  if (auto* constant_array = dynamic_cast<ConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(Literal* const new_literal,
                        TakeOwnership(constant_array->literal()->Reshape(
                            AsInt64Slice(shape.dimensions()))));
    return Construct<ConstantArray>(new_literal);
  }

  return Construct<ReshapedArray>(operand, shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
                                                         Array* lhs,
                                                         Array* rhs) {
  // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
  //          => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
  //
  // We can do this if every output dimension from the scalar-indexed node is a
  // broadcasted dimension for the broadcast node.  Informally, the precondition
  // means Broadcast(Const0)[IDX] is solely a function of the components of IDX
  // that are not output-dims for the scalar-indexed node. In other words, for
  // every assignment to the non-output dims in IDX we have a "constant" LHS to
  // the BinaryOp.  This transform propagates this "constant" to the source for
  // the scalar-indexed node.
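  //
  // A concrete example (illustrative shapes and names): let c0 = s32[3],
  // c1 = s32[7,3] and indices = s32[5], and consider
  //   Add(Broadcast(c0), (scalar-indexed-const c1 indices 0->[0]))
  // where the broadcast has dimensions={1} and both operands have shape
  // s32[5,3].  The scalar-indexed node's only output dim (0) is a broadcasted
  // dim of the broadcast, so this folds to
  //   (scalar-indexed-const Add(Broadcast'(c0), c1) indices 0->[0])
  // where Broadcast'(c0) broadcasts s32[3] into s32[7,3] along dimension 1.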

  ScalarIndexedConstantArray* lhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(lhs);
  ScalarIndexedConstantArray* rhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(rhs);

  bool lhs_is_indexed;

  // One of the operands must be scalar-indexed and the other must be a
  // broadcast of a constant.
  if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) {
    lhs_is_indexed = true;
  } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) {
    lhs_is_indexed = false;
  } else {
    return nullptr;
  }

  ScalarIndexedConstantArray* scalar_indexed_const =
      lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const;
  UnknownArray* candidate_broadcast_array =
      dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
  if (!candidate_broadcast_array ||
      candidate_broadcast_array->instruction().opcode() !=
          HloOpcode::kBroadcast) {
    return nullptr;
  }

  const HloInstruction* broadcast_instr =
      &candidate_broadcast_array->instruction();
  const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
  if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
    return nullptr;
  }

  absl::Span<const int64> broadcast_dims = broadcast_instr->dimensions();
  auto is_broadcasted_dim = [&](int64 output_dim) {
    return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
  };

  // All of the output dims must be "broadcasted" dims for the other operand.
  if (!absl::c_all_of(scalar_indexed_const->output_dims(),
                      is_broadcasted_dim)) {
    return nullptr;
  }

  // To figure out the broadcast dimensions for the (constant) source for the
  // scalar-indexed node, we "simulate" the index transformation done by the
  // existing broadcast:
  enum class IndexComponent { Broadcasted, NotBroadcasted };
  std::vector<IndexComponent> simulated_index(
      broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
  for (int64 broadcast_dim : broadcast_dims) {
    simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted;
  }

  // The scalar-indexed node "removes" the source dim and "inserts" the output
  // dims.  We do the opposite here to undo the scalar-indexed operation.
  absl::Span<const int64> output_dims = scalar_indexed_const->output_dims();
  for (int64 i = output_dims.size() - 1; i >= 0; --i) {
    CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
    EraseAt(&simulated_index, output_dims[i]);
  }

  InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
           IndexComponent::Broadcasted);

  // new_inner_broadcast_dims holds the broadcast dimensions for the inner
  // BinaryOp(Broadcast'(Const0), Const1).  We now translate simulated_index to
  // new_inner_broadcast_dims.
  std::vector<int64> new_inner_broadcast_dims;
  for (int64 i = 0; i < simulated_index.size(); i++) {
    if (simulated_index[i] == IndexComponent::NotBroadcasted) {
      new_inner_broadcast_dims.push_back(i);
    }
  }

  // inner_broadcast_result is the Broadcast'(Const0) bit in
  // BinaryOp(Broadcast'(Const0), Const1)
  TF_ASSIGN_OR_RETURN(
      Literal inner_broadcast_result,
      broadcast_const_operand->literal().Broadcast(
          scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));

  // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
  const Literal* literal_for_new_source;
  if (lhs_is_indexed) {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
  } else {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, inner_broadcast_result, scalar_indexed_const->literal())));
  }

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      std::vector<int64>(scalar_indexed_const->output_dims().begin(),
                         scalar_indexed_const->output_dims().end()),
      scalar_indexed_const->shape());
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
                                                        Array* operand) {
  auto* scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(operand);
  if (scalar_indexed_const == nullptr) {
    return nullptr;
  }

  // Fold UnaryOp(ScalarIndexed(Const, Indices))
  //   => ScalarIndexed(UnaryOp(Const), Indices)

  TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
                      TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
                          opcode, scalar_indexed_const->literal())));
  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      ArraySliceToVector(scalar_indexed_const->output_dims()),
      scalar_indexed_const->shape());
}

namespace {

// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
absl::optional<int64> GetOnlyNonContractingNonBatchDim(
    int64 rank, absl::Span<const int64> contracting_dims,
    absl::Span<const int64> batch_dims) {
  absl::optional<int64> result;
  for (int64 dim = 0; dim < rank; dim++) {
    if (!absl::c_linear_search(contracting_dims, dim) &&
        !absl::c_linear_search(batch_dims, dim)) {
      if (result.has_value()) {
        return absl::nullopt;
      }
      result = dim;
    }
  }
  return result;
}

// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot
// HLO, can be folded into the dot operation.  For now these conditions are both
// necessary and sufficient.
//
// `tag` describes the caller.  Used only for logging.
//
// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
    absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
    absl::Span<const int64> contracting_dims,
    absl::Span<const int64> batch_dims) {
  absl::optional<int64> non_contracting_non_batch_dim =
      GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(),
                                       contracting_dims, batch_dims);
  if (!non_contracting_non_batch_dim.has_value()) {
    VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
    return false;
  }

  if (indexed_array->output_dims().size() != 1 ||
      indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) {
    VLOG(3) << tag << ": output dims != the non-contracting non-batch dim";
    return false;
  }

  int64 indexed_array_rank = indexed_array->shape().rank();
  if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
    // This restriction can be lifted by inserting reshape nodes.
    VLOG(3) << tag
            << ": source dim is not in the low two dims, won't be able to form "
               "a matmul";
    return false;
  }

  return true;
}

}  // namespace

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
    ConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/
          AsInt64Slice(dim_numbers.lhs_contracting_dimensions()),
          /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) {
    return nullptr;
  }

  int64 lhs_rank = lhs->shape().rank();
  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_lhs_contracting_dimensions(
      0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting LHS
  // dimension "went".
  int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                         dim_numbers.rhs_batch_dimensions_size();

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, lhs->indices(), new_source_dim,
      ArraySliceToVector(lhs->output_dims()), shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ConstantArray* lhs,
    ScalarIndexedConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/
          AsInt64Slice(dim_numbers.rhs_contracting_dimensions()),
          /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) {
    return nullptr;
  }

  int64 rhs_rank = rhs->shape().rank();

  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_rhs_contracting_dimensions(
      0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting RHS
  // dimension "went".
  int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                         dim_numbers.rhs_batch_dimensions_size() + 1;

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, rhs->indices(), new_source_dim,
      ArraySliceToVector(rhs->output_dims()), shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
  // Intuitively, if
  //
  //  - The LHS of a dot product is a gathered sequence of rows from a constant
  //    array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant
  //
  //  OR
  //
  //  - If the RHS of a dot product is a gathered sequence of columns from a
  //    constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a
  //    constant
  //
  // then the result of the dot product itself is a gather from a constant
  // array.  E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be
  // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I],
  // J].
  //
  // We do a general version of this rewrite here.
  VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs)
          << ")";
  if (auto* lhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
    if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
      return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
                                              precision_config,
                                              lhs_indexed_array, rhs_constant);
    }
  }

  if (auto* rhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
    if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
      return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
                                              precision_config, lhs_constant,
                                              rhs_indexed_array);
    }
  }

  return nullptr;
}

absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
  return "indexed-array-analysis-printer-pass";
}

StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) {
  if (!VLOG_IS_ON(2)) {
    return false;
  }

  IndexedArrayAnalysis analysis;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instr : computation->instructions()) {
      TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
      if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
        VLOG(2) << instr->ToString() << "   ->   " << analysis.ToString(t);
      }
    }
  }

  return false;
}

}  // namespace xla