/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"

#include <algorithm>
#include <array>
#include <map>
#include <memory>
#include <vector>

#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, int64 block_size) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
                      builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
                      builder->GetShape(b));
  if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) {
    return errors::InvalidArgument(
        "Arguments to TriangularSolve have different ranks: ",
        xla::ShapeUtil::HumanString(*a_shape), " vs. ",
        xla::ShapeUtil::HumanString(*b_shape));
  }
  const int ndims = xla::ShapeUtil::Rank(*a_shape);
  if (ndims < 2) {
    return errors::InvalidArgument(
        "Arguments to TriangularSolve must have rank >= 2: ", ndims);
  }
  // The batch dimensions must be equal.
  std::vector<int64> batch_dimensions;
  for (int i = 0; i < ndims - 2; ++i) {
    int64 a_size = a_shape->dimensions(i);
    int64 b_size = b_shape->dimensions(i);
    if (a_size != b_size) {
      return errors::InvalidArgument(
          "Batch dimensions of arguments to TriangularSolve must be equal: ",
          xla::ShapeUtil::HumanString(*a_shape), " vs ",
          xla::ShapeUtil::HumanString(*b_shape));
    }
    batch_dimensions.push_back(a_size);
  }

  if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
      xla::ShapeUtil::GetDimension(*a_shape, -2)) {
    return errors::InvalidArgument(
        "The 'a' argument to TriangularSolve must be square matrices: ",
        xla::ShapeUtil::HumanString(*a_shape));
  }
  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
  if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
    return errors::InvalidArgument(
        "Arguments to TriangularSolve have incompatible matrix shapes: ",
        xla::ShapeUtil::HumanString(*a_shape), " vs ",
        xla::ShapeUtil::HumanString(*b_shape));
  }

  if (block_size < 1) {
    return errors::InvalidArgument(
        "block_size argument to TriangularSolve must be >= 1; got ",
        block_size);
  }

  // Returns [b1, b2, ... , bn, indices[0], indices[1]].
  auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
    std::vector<int64> output(ndims);
    std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
    std::copy(indices.begin(), indices.end(),
              output.begin() + batch_dimensions.size());
    return output;
  };

  // Applies a complex conjugation operation if `a` is complex and `conjugate_a`
  // is true, otherwise returns its argument.
  auto maybe_conj = [&](xla::ComputationBuilder* builder,
                        xla::ComputationDataHandle x) {
    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
    return perform_conj ? builder->Conj(x) : x;
  };

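  // Cache of base-case triangular solve computations, keyed by block size, so
  // that diagonal blocks of the same size share a single subcomputation.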
  std::map<int, xla::Computation> base_computations;
  auto get_base_triangular_solve =
      [&](int k) -> xla::StatusOr<xla::Computation*> {
    xla::Computation& computation = base_computations[k];
    if (computation.IsNull()) {
      std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
          tensorflow::strings::StrCat("trsm_base_", k));

      auto a_param =
          sub->Parameter(0,
                         xla::ShapeUtil::MakeShape(b_shape->element_type(),
                                                   prepend_batch_dims({k, k})),
                         "a");

      std::array<int64, 2> b_lastd;
      if (left_side) {
        b_lastd = {k, n};
      } else {
        b_lastd = {m, k};
      }
      auto b_param =
          sub->Parameter(1,
                         xla::ShapeUtil::MakeShape(b_shape->element_type(),
                                                   prepend_batch_dims(b_lastd)),
                         "b");

      // We use a left-looking subroutine on the block diagonal in some common
      // cases, while falling back to a recursive call in unsupported cases. The
      // left-looking subroutine is written with a While loop and so yields much
      // faster compile times. Moreover, the left-looking variant can give
      // higher performance on smaller (sub)problems.
      if (left_side && lower) {
        TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
                                                      b_param, transpose_a,
                                                      conjugate_a)
                               .status());
      } else {
        TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
                                           left_side, lower, transpose_a,
                                           conjugate_a,
                                           /*block_size=*/1)
                               .status());
      }

      TF_ASSIGN_OR_RETURN(computation, sub->Build());
    }
    return &computation;
  };

  xla::ComputationDataHandle output = Zeros(builder, *b_shape);

  // Right-looking blocked triangular solve.
  // For an explanation of the algorithm, see the TRSM discussion in:
  // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
  // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
  // (2008): 4.

  // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
  // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
  // conjugate_a is True.

  if (!left_side && lower == transpose_a) {
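    // Here op(a) (a, transposed and conjugated as requested) is upper
    // triangular and multiplies the unknowns from the right, so we sweep
    // forward over block columns of b.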
    // for i in range(0, a.shape[-1], block_size):
    for (int64 i = 0; i < n; i += block_size) {
      int64 k = std::min(block_size, n - i);

      // output[..., :, i:i+k] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
      xla::ComputationDataHandle update;
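      // A 1x1 diagonal block reduces the solve to an elementwise division;
      // larger blocks invoke (and cache) a dedicated base solve computation.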
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {0, i}));

      // if i + k < a.shape[-1]:
      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
      if (i + k < n) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
        } else {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, update, a_slice_2,
                                     /*transpose_x=*/false,
                                     /*transpose_y=*/transpose_a,
                                     /*conjugate_x=*/false,
                                     /*conjugate_y=*/conjugate_a));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
      }
    }

  } else if (left_side && lower != transpose_a) {
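    // Here op(a) is lower triangular and multiplies the unknowns from the
    // left, so we sweep forward over block rows of b.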
    // for i in range(0, a.shape[-1], block_size):
    for (int64 i = 0; i < m; i += block_size) {
      int64 k = std::min(block_size, m - i);

      // output[..., i:i+k, :] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));

      // if i + k < a.shape[-1]:
      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
      if (i + k < m) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
        } else {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
                                                    /*transpose_x=*/transpose_a,
                                                    /*transpose_y=*/false,
                                                    /*conjugate_x=*/conjugate_a,
                                                    /*conjugate_y=*/false));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
      }
    }
  } else if (!left_side && lower != transpose_a) {
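    // Here op(a) is lower triangular and multiplies the unknowns from the
    // right, so we sweep backward over block columns of b.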
    // for i in reversed(range(0, a.shape[-1], block_size)):
    const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size;
    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
      int64 k = std::min(block_size, n - i);

      // output[..., :, i:i+k] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {0, i}));

      // if i - k >= 0:
      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
      if (i - k >= 0) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
        } else {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, update, a_slice_2,
                                     /*transpose_x=*/false,
                                     /*transpose_y=*/transpose_a,
                                     /*conjugate_x=*/false,
                                     /*conjugate_y=*/conjugate_a));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, 0}, {m, i}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
      }
    }
  } else {  // left_side && lower == transpose_a
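    // Here op(a) is upper triangular and multiplies the unknowns from the
    // left, so we sweep backward over block rows of b.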
    // for i in reversed(range(0, a.shape[-1], block_size)):
    const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size;
    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
      int64 k = std::min(block_size, m - i);

      // output[..., i:i+k, :] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));

      // if i - k >= 0:
      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
      if (i - k >= 0) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
        } else {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
                                                    /*transpose_x=*/transpose_a,
                                                    /*transpose_y=*/false,
                                                    /*conjugate_x=*/conjugate_a,
                                                    /*conjugate_y=*/false));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, 0}, {i, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
      }
    }
  }

  return output;
}

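// Left-looking triangular solve for the left_side && lower case. One row of
// the output is computed per iteration of an XLA While loop, which keeps the
// emitted computation small and the compile time short compared with fully
// unrolling the solve.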
xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
                      builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
                      builder->GetShape(b));
  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
  const int64 ndims = xla::ShapeUtil::Rank(*a_shape);

  std::vector<int64> batch_dimensions;
  for (int i = 0; i < ndims - 2; ++i) {
    int64 a_size = a_shape->dimensions(i);
    batch_dimensions.push_back(a_size);
  }

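  // As in TriangularSolve above: returns [b1, ..., bn, indices[0], indices[1]].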
  auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
    std::vector<int64> output(ndims);
    std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
    std::copy(indices.begin(), indices.end(),
              output.begin() + batch_dimensions.size());
    return output;
  };

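  // Conjugates `x` if `a` is complex and `conjugate_a` is set; otherwise
  // returns `x` unchanged.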
  auto maybe_conj = [&](xla::ComputationBuilder* builder,
                        xla::ComputationDataHandle x) {
    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
    return perform_conj ? builder->Conj(x) : x;
  };

  // The main computation is performed in a While loop.

  // Allocate the output and set its first or last row,
  // output = np.zeros_like(b)
  // if transpose_a:
  //   output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
  // else:
  //   output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
  xla::ComputationDataHandle output = Zeros(builder, *b_shape);
  {
    auto i = transpose_a ? m - 1 : 0;
    TF_ASSIGN_OR_RETURN(auto a_slice,
                        SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
    TF_ASSIGN_OR_RETURN(auto b_slice,
                        SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
    auto update = builder->Div(b_slice, maybe_conj(builder, a_slice));
    TF_ASSIGN_OR_RETURN(
        output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
  }

  // Construct the initial loop carry tuple,
  // if transpose_a:
  //   init = (m-2, output, a, b)
  // else:
  //   init = (1, output, a, b)
  std::vector<xla::Shape> tuple_shapes = {
      // The loop iteration counter is a scalar, incremented or decremented
      // each iteration.
      xla::ShapeUtil::MakeShape(xla::S32, {}),
      // The output has the shape of b, with one row updated each iteration.
      *b_shape,
      // The coefficient matrix a is a loop invariant.
      *a_shape,
      // The right-hand-side matrix b is a loop invariant.
      *b_shape};
  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
  auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
  auto init = builder->Tuple({init_i, output, a, b});

  // Construct the loop condition function,
  // def cond_fun(loop_carry):
  //   i, output, a, b = loop_carry
  //   return i >= 0 if transpose_a else i < m
  std::unique_ptr<xla::ComputationBuilder> condb =
      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
  {
    auto i = condb->GetTupleElement(
        condb->Parameter(0, tuple_shape,
                         "TriangularSolveLeftLookingWhileTuple"),
        0);
    if (transpose_a) {
      condb->Ge(i, condb->ConstantR0<int32>(0));
    } else {
      condb->Lt(i, condb->ConstantR0<int32>(m));
    }
  }
  TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

  // Construct the loop body function,
  // def body_fun(loop_carry):
  //   i, output, a, b = loop_carry
  //   if transpose_a:
  //     a_row = np.swapaxes(a[..., i+1:, i:i+1], -1, -2)
  //   else:
  //     a_row = a[..., i:i+1, :i]
  //   result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
  //   output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
  //   if transpose_a:
  //     return (i - 1, output, a, b)
  //   else:
  //     return (i + 1, output, a, b)
  // We have to do some extra FLOPs propagating zeros in the matrix multiply
  // because we can't have the size of its arguments depend on the loop counter.
  std::unique_ptr<xla::ComputationBuilder> bodyb =
      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
  {
    auto input_tuple = bodyb->Parameter(0, tuple_shape,
                                        "TriangularSolveLeftLookingWhileTuple");

    // i, output, a, b = loop_carry
    auto i = bodyb->GetTupleElement(input_tuple, 0);
    auto body_out = bodyb->GetTupleElement(input_tuple, 1);
    auto body_a = bodyb->GetTupleElement(input_tuple, 2);
    auto body_b = bodyb->GetTupleElement(input_tuple, 3);
    auto zero = bodyb->ConstantR0<int32>(0);

    // Set up some helper functions.
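    // prepend_zeros pads a pair of dynamic start indices with zeros for the
    // batch dimensions, so dynamic slices and updates only index into the two
    // minor (matrix) dimensions.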
    auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) {
      auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1});
      std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero);
      padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1});
      padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1});
      return bodyb->ConcatInDim(padded_starts, 0);
    };

    auto dynamic_slice = [&](xla::ComputationDataHandle x,
                             std::array<xla::ComputationDataHandle, 2> starts,
                             std::array<int64, 2> sizes) {
      auto padded_starts = prepend_zeros(starts);
      auto padded_sizes = prepend_batch_dims(sizes);
      return bodyb->DynamicSlice(x, padded_starts, padded_sizes);
    };

    auto update = [&](xla::ComputationDataHandle x,
                      xla::ComputationDataHandle update,
                      std::array<xla::ComputationDataHandle, 2> starts) {
      auto padded_starts = prepend_zeros(starts);
      return bodyb->DynamicUpdateSlice(x, update, padded_starts);
    };

    // We'd like to implement this:
    //   if transpose_a:
    //     a_row = T(a[..., i+1:, i:i+1])
    //     result_row = (b[..., i:i+1, :]
    //                   - np.matmul(a_row, body_out[..., i+1:, :]))
    //   else:
    //     result_row = (b[..., i:i+1, :]
    //                   - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
    // But since we can't have intermediate array sizes depend on the loop
    // counter, we instead exploit the fact that we initialized the output to
    // all zeros and use that as zero-padding (doing unnecessary FLOPs).
    xla::ComputationDataHandle a_row;
    if (transpose_a) {
      a_row = dynamic_slice(body_a, {zero, i}, {m, 1});
    } else {
      a_row = dynamic_slice(body_a, {i, zero}, {1, m});
    }
    TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
                                                /*transpose_x=*/transpose_a,
                                                /*transpose_y=*/false,
                                                /*conjugate_x=*/conjugate_a,
                                                /*conjugate_y=*/false));
    auto result_row =
        bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update);

    // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
    auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1});
    auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
    body_out = update(body_out, div_result, {i, zero});

    // if transpose_a:
    //   return (i - 1, body_out, a, b)
    // else:
    //   return (i + 1, body_out, a, b)
    auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1));
    bodyb->Tuple({next_i, body_out, body_a, body_b});
  }
  TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

  // Construct the While loop and return the result,
  // return while_loop(cond_fun, body_fun, init)[1]
  auto triangular_solve_left_looking_while = builder->While(cond, body, init);
  return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
}

}  // namespace tensorflow