/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"

#include <algorithm>
#include <array>
#include <map>
#include <memory>
#include <vector>

#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, int64 block_size) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
                      builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
                      builder->GetShape(b));
  if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) {
    return errors::InvalidArgument(
        "Arguments to TriangularSolve have different ranks: ",
        xla::ShapeUtil::HumanString(*a_shape), " vs. ",
        xla::ShapeUtil::HumanString(*b_shape));
  }
  const int ndims = xla::ShapeUtil::Rank(*a_shape);
  if (ndims < 2) {
    return errors::InvalidArgument(
        "Arguments to TriangularSolve must have rank >= 2: ", ndims);
  }
  // The batch dimensions must be equal.
  std::vector<int64> batch_dimensions;
  for (int i = 0; i < ndims - 2; ++i) {
    int64 a_size = a_shape->dimensions(i);
    int64 b_size = b_shape->dimensions(i);
    if (a_size != b_size) {
      return errors::InvalidArgument(
          "Batch dimensions of arguments to TriangularSolve must be equal: ",
          xla::ShapeUtil::HumanString(*a_shape), " vs ",
          xla::ShapeUtil::HumanString(*b_shape));
    }
    batch_dimensions.push_back(a_size);
  }

  if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
      xla::ShapeUtil::GetDimension(*a_shape, -2)) {
    return errors::InvalidArgument(
        "The 'a' argument to TriangularSolve must be a square matrix: ",
        xla::ShapeUtil::HumanString(*a_shape));
  }
  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
  if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
    return errors::InvalidArgument(
        "Arguments to TriangularSolve have incompatible matrix shapes: ",
        xla::ShapeUtil::HumanString(*a_shape), " vs ",
        xla::ShapeUtil::HumanString(*b_shape));
  }

  if (block_size < 1) {
    return errors::InvalidArgument(
        "block_size argument to TriangularSolve must be >= 1; got ",
        block_size);
  }

  // Returns [b1, b2, ... , bn, indices[0], indices[1]].
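  // For example, with batch_dimensions = [5, 6], prepend_batch_dims({2, 3})
  // returns [5, 6, 2, 3].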
  auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
    std::vector<int64> output(ndims);
    std::copy(batch_dimensions.begin(), batch_dimensions.end(),
              output.begin());
    std::copy(indices.begin(), indices.end(),
              output.begin() + batch_dimensions.size());
    return output;
  };

  // Applies a complex conjugation operation if `a` is complex and
  // `conjugate_a` is true, otherwise returns its argument.
  auto maybe_conj = [&](xla::ComputationBuilder* builder,
                        xla::ComputationDataHandle x) {
    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
    return perform_conj ? builder->Conj(x) : x;
  };

  std::map<int, xla::Computation> base_computations;
  auto get_base_triangular_solve =
      [&](int k) -> xla::StatusOr<xla::Computation*> {
    xla::Computation& computation = base_computations[k];
    if (computation.IsNull()) {
      std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
          tensorflow::strings::StrCat("trsm_base_", k));

      auto a_param =
          sub->Parameter(0,
                         xla::ShapeUtil::MakeShape(b_shape->element_type(),
                                                   prepend_batch_dims({k, k})),
                         "a");

      std::array<int64, 2> b_lastd;
      if (left_side) {
        b_lastd = {k, n};
      } else {
        b_lastd = {m, k};
      }
      auto b_param =
          sub->Parameter(1,
                         xla::ShapeUtil::MakeShape(b_shape->element_type(),
                                                   prepend_batch_dims(b_lastd)),
                         "b");

      // We use a left-looking subroutine on the block diagonal in the common
      // cases it supports, and fall back to a recursive call for the others.
      // The left-looking subroutine is written with a While loop and so
      // yields much faster compile times. Moreover, the left-looking variant
      // can give higher performance on smaller (sub)problems.
      if (left_side && lower) {
        TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
                                                      b_param, transpose_a,
                                                      conjugate_a)
                               .status());
      } else {
        TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
                                           left_side, lower, transpose_a,
                                           conjugate_a,
                                           /*block_size=*/1)
                               .status());
      }

      TF_ASSIGN_OR_RETURN(computation, sub->Build());
    }
    return &computation;
  };

  xla::ComputationDataHandle output = Zeros(builder, *b_shape);

  // Right-looking blocked triangular solve.
  // For an explanation of the algorithm, see the TRSM discussion in:
  // Goto, Kazushige, and Robert Van De Geijn. "High-performance
  // implementation of the level-3 BLAS." ACM Transactions on Mathematical
  // Software (TOMS) 35.1 (2008): 4.

  // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
  // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
  // conjugate_a is True.
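
  // Each branch below sweeps over `a` one block of block_size columns (or
  // rows) at a time; which branch applies depends on whether the effective
  // triangular operator, after any transpose, is upper or lower triangular.
  // For instance, when !left_side and lower == transpose_a the effective
  // operator is upper triangular, so x @ U = b is solved by sweeping forward
  // over column blocks. As an illustrative NumPy sketch of that case with
  // lower and transpose_a both true (base_solve stands in for the k x k
  // block solve; not part of the build):
  //   for i in range(0, n, block_size):
  //     k = min(block_size, n - i)
  //     output[..., :, i:i+k] = base_solve(a[..., i:i+k, i:i+k],
  //                                        b[..., :, i:i+k])
  //     if i + k < n:
  //       b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k],
  //                                    T(a[..., i+k:, i:i+k]))
  // The remaining three branches follow by symmetry.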

  if (!left_side && lower == transpose_a) {
    // for i in range(0, a.shape[-1], block_size):
    for (int64 i = 0; i < n; i += block_size) {
      int64 k = std::min(block_size, n - i);

      // output[..., :, i:i+k] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {0, i}));

      // if i + k < a.shape[-1]:
      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
      if (i + k < n) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
        } else {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, update, a_slice_2,
                                     /*transpose_x=*/false,
                                     /*transpose_y=*/transpose_a,
                                     /*conjugate_x=*/false,
                                     /*conjugate_y=*/conjugate_a));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
      }
    }

  } else if (left_side && lower != transpose_a) {
    // for i in range(0, a.shape[-1], block_size):
    for (int64 i = 0; i < m; i += block_size) {
      int64 k = std::min(block_size, m - i);

      // output[..., i:i+k, :] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));

      // if i + k < a.shape[-1]:
      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
      if (i + k < m) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
        } else {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, a_slice_2, update,
                                     /*transpose_x=*/transpose_a,
                                     /*transpose_y=*/false,
                                     /*conjugate_x=*/conjugate_a,
                                     /*conjugate_y=*/false));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
      }
    }
  } else if (!left_side && lower != transpose_a) {
    // for i in reversed(range(0, a.shape[-1], block_size)):
    const int64 last_blk_ix =
        xla::RoundUpToNearest(n, block_size) - block_size;
    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
      int64 k = std::min(block_size, n - i);

      // output[..., :, i:i+k] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {0, i}));

      // if i - k >= 0:
      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :, :i] -= np.matmul(output[..., :, i:i+k], a_slice_2)
      if (i - k >= 0) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
        } else {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, update, a_slice_2,
                                     /*transpose_x=*/false,
                                     /*transpose_y=*/transpose_a,
                                     /*conjugate_x=*/false,
                                     /*conjugate_y=*/conjugate_a));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, 0}, {m, i}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
      }
    }
  } else {  // left_side && lower == transpose_a
    // for i in reversed(range(0, a.shape[-1], block_size)):
    const int64 last_blk_ix =
        xla::RoundUpToNearest(m, block_size) - block_size;
    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
      int64 k = std::min(block_size, m - i);

      // output[..., i:i+k, :] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));

      // if i - k >= 0:
      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :i, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
      if (i - k >= 0) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
        } else {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, a_slice_2, update,
                                     /*transpose_x=*/transpose_a,
                                     /*transpose_y=*/false,
                                     /*conjugate_x=*/conjugate_a,
                                     /*conjugate_y=*/false));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, 0}, {i, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
      }
    }
  }

  return output;
}
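
// Left-looking triangular solve, used by TriangularSolve above as the base
// case when left_side && lower. One row of the output is computed per
// iteration of an XLA While loop. As an illustrative NumPy sketch of the
// non-transposed (forward-substitution) case, ignoring the zero-padding
// trick described in the loop body below (not part of the build):
//   output = np.zeros_like(b)
//   output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
//   for i in range(1, m):
//     row = b[..., i:i+1, :] - np.matmul(a[..., i:i+1, :i],
//                                        output[..., :i, :])
//     output[..., i:i+1, :] = row / a[..., i:i+1, i:i+1]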
xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
                      builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
                      builder->GetShape(b));
  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
  const int64 ndims = xla::ShapeUtil::Rank(*a_shape);

  std::vector<int64> batch_dimensions;
  for (int i = 0; i < ndims - 2; ++i) {
    int64 a_size = a_shape->dimensions(i);
    batch_dimensions.push_back(a_size);
  }

  auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
    std::vector<int64> output(ndims);
    std::copy(batch_dimensions.begin(), batch_dimensions.end(),
              output.begin());
    std::copy(indices.begin(), indices.end(),
              output.begin() + batch_dimensions.size());
    return output;
  };

  auto maybe_conj = [&](xla::ComputationBuilder* builder,
                        xla::ComputationDataHandle x) {
    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
    return perform_conj ? builder->Conj(x) : x;
  };

  // The main computation is performed in a While loop.

  // Allocate the output and set its first or last row,
  // output = np.zeros_like(b)
  // if transpose_a:
  //   output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
  // else:
  //   output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
  xla::ComputationDataHandle output = Zeros(builder, *b_shape);
  {
    auto i = transpose_a ? m - 1 : 0;
    TF_ASSIGN_OR_RETURN(auto a_slice,
                        SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
    TF_ASSIGN_OR_RETURN(auto b_slice,
                        SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
    auto update = builder->Div(b_slice, maybe_conj(builder, a_slice));
    TF_ASSIGN_OR_RETURN(
        output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
  }

  // Construct the initial loop carry tuple,
  // if transpose_a:
  //   init = (m-2, output, a, b)
  // else:
  //   init = (1, output, a, b)
  std::vector<xla::Shape> tuple_shapes = {
      // The loop iteration counter is a scalar, incremented each iteration.
      xla::ShapeUtil::MakeShape(xla::S32, {}),
      // The output has the shape of b, with one row updated each iteration.
      *b_shape,
      // The coefficient matrix a is a loop invariant.
      *a_shape,
      // The right-hand-side matrix b is a loop invariant.
      *b_shape};
  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
  auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
  auto init = builder->Tuple({init_i, output, a, b});

  // Construct the loop condition function,
  // def cond_fun(loop_carry):
  //   i, output, a, b = loop_carry
  //   return i >= 0 if transpose_a else i < m
  std::unique_ptr<xla::ComputationBuilder> condb =
      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
  {
    auto i = condb->GetTupleElement(
        condb->Parameter(0, tuple_shape,
                         "TriangularSolveLeftLookingWhileTuple"),
        0);
    if (transpose_a) {
      condb->Ge(i, condb->ConstantR0<int32>(0));
    } else {
      condb->Lt(i, condb->ConstantR0<int32>(m));
    }
  }
  TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

  // Construct the loop body function,
  // def body_fun(loop_carry):
  //   i, output, a, b = loop_carry
  //   if transpose_a:
  //     a_row = np.swapaxes(a[..., i+1:, i:i+1], -1, -2)
  //   else:
  //     a_row = a[..., i:i+1, :i]
  //   result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
  //   output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
  //   if transpose_a:
  //     return (i - 1, output, a, b)
  //   else:
  //     return (i + 1, output, a, b)
  // We have to do some extra FLOPs propagating zeros in the matrix multiply
  // because we can't have the size of its arguments depend on the loop
  // counter.
  std::unique_ptr<xla::ComputationBuilder> bodyb =
      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
  {
    auto input_tuple = bodyb->Parameter(0, tuple_shape,
                                        "TriangularSolveLeftLookingWhileTuple");

    // i, output, a, b = loop_carry
    auto i = bodyb->GetTupleElement(input_tuple, 0);
    auto body_out = bodyb->GetTupleElement(input_tuple, 1);
    auto body_a = bodyb->GetTupleElement(input_tuple, 2);
    auto body_b = bodyb->GetTupleElement(input_tuple, 3);
    auto zero = bodyb->ConstantR0<int32>(0);

    // Set up some helper functions.
    auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) {
      auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1});
      std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero);
      padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1});
      padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1});
      return bodyb->ConcatInDim(padded_starts, 0);
    };

    auto dynamic_slice = [&](xla::ComputationDataHandle x,
                             std::array<xla::ComputationDataHandle, 2> starts,
                             std::array<int64, 2> sizes) {
      auto padded_starts = prepend_zeros(starts);
      auto padded_sizes = prepend_batch_dims(sizes);
      return bodyb->DynamicSlice(x, padded_starts, padded_sizes);
    };

    auto update = [&](xla::ComputationDataHandle x,
                      xla::ComputationDataHandle update,
                      std::array<xla::ComputationDataHandle, 2> starts) {
      auto padded_starts = prepend_zeros(starts);
      return bodyb->DynamicUpdateSlice(x, update, padded_starts);
    };

    // We'd like to implement this:
    //   if transpose_a:
    //     a_row = T(a[..., i+1:, i:i+1])
    //     result_row = (b[..., i:i+1, :]
    //                   - np.matmul(a_row, body_out[..., i+1:, :]))
    //   else:
    //     result_row = (b[..., i:i+1, :]
    //                   - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
    // But since we can't have intermediate array sizes depend on the loop
    // counter, we instead exploit the fact that we initialized the output to
    // all zeros and use that as zero-padding (doing unnecessary FLOPs).
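    // For example, in the non-transposed case the slice below takes the full
    // row a[..., i:i+1, :m] rather than a[..., i:i+1, :i]; the extra columns
    // j >= i multiply rows of body_out that are still zero, so they add
    // nothing to the product.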
    xla::ComputationDataHandle a_row;
    if (transpose_a) {
      a_row = dynamic_slice(body_a, {zero, i}, {m, 1});
    } else {
      a_row = dynamic_slice(body_a, {i, zero}, {1, m});
    }
    TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
                                                /*transpose_x=*/transpose_a,
                                                /*transpose_y=*/false,
                                                /*conjugate_x=*/conjugate_a,
                                                /*conjugate_y=*/false));
    auto result_row =
        bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update);

    // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
    auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1});
    auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
    body_out = update(body_out, div_result, {i, zero});

    // if transpose_a:
    //   return (i - 1, body_out, a, b)
    // else:
    //   return (i + 1, body_out, a, b)
    auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1));
    bodyb->Tuple({next_i, body_out, body_a, body_b});
  }
  TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

  // Construct the While loop and return the result,
  // return while_loop(cond_fun, body_fun, init)[1]
  auto triangular_solve_left_looking_while = builder->While(cond, body, init);
  return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
}

}  // namespace tensorflow