1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 20 namespace tensorflow { 21 22 using shape_inference::DimensionHandle; 23 using shape_inference::InferenceContext; 24 using shape_inference::ShapeHandle; 25 26 namespace { 27 28 // Return in <out> the result of making the end of <s> a square matrix. 29 Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input, 30 ShapeHandle* out) { 31 ShapeHandle s; 32 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s)); 33 34 DimensionHandle d; 35 TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d)); 36 37 ShapeHandle batch_shape; 38 TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape)); 39 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out)); 40 return Status::OK(); 41 } 42 43 Status BatchUnchangedSquareShapeFn(InferenceContext* c) { 44 ShapeHandle out; 45 TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out)); 46 c->set_output(0, out); 47 return Status::OK(); 48 } 49 50 // The first input is [...,M,N] and second input is either [...,M,K] or [...,M]. 51 // Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M]. 
52 Status MatrixSolveShapeFn(InferenceContext* c, bool square) { 53 ShapeHandle lhs; 54 ShapeHandle rhs; 55 if (square) { 56 TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs)); 57 } else { 58 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs)); 59 } 60 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); 61 62 ShapeHandle lhs_batch_shape; 63 ShapeHandle rhs_batch_shape; 64 // Make the common batch subshape. 65 TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); 66 TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); 67 // Make sure the batch dimensions match between lhs and rhs. 68 TF_RETURN_IF_ERROR( 69 c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape)); 70 71 DimensionHandle m; 72 // lhs and rhs have the same value for m to be compatible. 73 TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m)); 74 DimensionHandle n = c->Dim(lhs, -1); 75 if (square) { 76 TF_RETURN_IF_ERROR(c->Merge(m, n, &n)); 77 } 78 79 ShapeHandle out; 80 // Build final shape (batch_shape + n + k) in <out>. 81 TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out)); 82 TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out)); 83 c->set_output(0, out); 84 return Status::OK(); 85 } 86 87 // Input is [...,N,N]. Outputs are: 88 // [...,N];[0], if compute_v is false, 89 // [...,N];[...,N,N], if compute_v is true. 
90 Status SelfAdjointEigV2ShapeFn(InferenceContext* c) { 91 ShapeHandle input; 92 TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input)); 93 DimensionHandle n; 94 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n)); 95 ShapeHandle batch_shape; 96 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); 97 ShapeHandle e_shape; 98 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape)); 99 c->set_output(0, e_shape); 100 bool compute_v; 101 TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v)); 102 if (compute_v) { 103 ShapeHandle v_shape; 104 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape)); 105 c->set_output(1, v_shape); 106 } else { 107 c->set_output(1, c->Vector(0ll)); 108 } 109 return Status::OK(); 110 } 111 112 // Input is [...,N,N]. 113 // First and second outputs are: 114 // [...,N,N]; [...,N]. 115 Status LuShapeFn(InferenceContext* c) { 116 ShapeHandle input; 117 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); 118 119 DimensionHandle n; 120 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n)); 121 122 ShapeHandle batch_shape; 123 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); 124 125 ShapeHandle lu_shape; 126 ShapeHandle p_shape; 127 128 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape)); 129 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape)); 130 131 c->set_output(0, lu_shape); 132 c->set_output(1, p_shape); 133 return Status::OK(); 134 } 135 136 // Input is [...,M,N]. 137 // First and second outputs are: 138 // [...,M,M]; [...,M,N], if full_matrices is true, 139 // [...,M,P]; [...,P,N], if full_matrices is false, 140 // where P = min(M,N). 
141 Status QrShapeFn(InferenceContext* c) { 142 ShapeHandle input; 143 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); 144 DimensionHandle m = c->Dim(input, -2); 145 DimensionHandle n = c->Dim(input, -1); 146 DimensionHandle p; 147 TF_RETURN_IF_ERROR(c->Min(m, n, &p)); 148 ShapeHandle batch_shape; 149 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); 150 ShapeHandle q_shape; 151 ShapeHandle r_shape; 152 bool full_matrices; 153 TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices)); 154 if (full_matrices) { 155 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape)); 156 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape)); 157 } else { 158 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape)); 159 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape)); 160 } 161 c->set_output(0, q_shape); 162 c->set_output(1, r_shape); 163 return Status::OK(); 164 } 165 166 // Input is [...,M,N]. First output is [...,min(M,N)]. 167 // Second and third outputs are: 168 // [0]; [0], if compute_uv is false. 169 // [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true, 170 // [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false, 171 // where P = min(M,N). 
172 Status SvdShapeFn(InferenceContext* c) { 173 ShapeHandle input; 174 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); 175 DimensionHandle m = c->Dim(input, -2); 176 DimensionHandle n = c->Dim(input, -1); 177 DimensionHandle p; 178 TF_RETURN_IF_ERROR(c->Min(m, n, &p)); 179 ShapeHandle batch_shape; 180 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape)); 181 ShapeHandle e_shape; 182 TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape)); 183 c->set_output(0, e_shape); 184 bool compute_uv; 185 TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv)); 186 if (compute_uv) { 187 ShapeHandle u_shape; 188 ShapeHandle v_shape; 189 bool full_matrices; 190 TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices)); 191 if (full_matrices) { 192 TF_RETURN_IF_ERROR( 193 c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape)); 194 TF_RETURN_IF_ERROR( 195 c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape)); 196 } else { 197 TF_RETURN_IF_ERROR( 198 c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape)); 199 TF_RETURN_IF_ERROR( 200 c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape)); 201 } 202 c->set_output(1, u_shape); 203 c->set_output(2, v_shape); 204 } else { 205 c->set_output(1, c->Vector(0ll)); 206 c->set_output(2, c->Vector(0ll)); 207 } 208 return Status::OK(); 209 } 210 211 // The first input is [...,3,M] and second input is [...,M,K]. 212 // Output is [...,M,K]. 213 Status TridiagonalSolveShapeFn(InferenceContext* c) { 214 ShapeHandle lhs; 215 ShapeHandle rhs; 216 // Check that rank is at least 2. 217 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs)); 218 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); 219 220 // Extract batch dimensions and check they are the same. 
221 ShapeHandle lhs_batch_shape; 222 ShapeHandle rhs_batch_shape; 223 TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); 224 TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); 225 TF_RETURN_IF_ERROR( 226 c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape)); 227 228 // Check that "M" is the same in both inputs. 229 DimensionHandle m_lhs = c->Dim(lhs, -1); 230 DimensionHandle m_rhs = c->Dim(rhs, -2); 231 TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs)); 232 233 // Check that next-to-last dimension of the first input is 3. 234 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs)); 235 236 // The output shape is the same as rhs shape. 237 c->set_output(0, rhs); 238 return Status::OK(); 239 } 240 241 } // namespace 242 243 REGISTER_OP("MatrixDeterminant") 244 .Input("input: T") 245 .Output("output: T") 246 .Attr("T: {half, float, double, complex64, complex128}") 247 .SetShapeFn([](InferenceContext* c) { 248 ShapeHandle input; 249 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); 250 251 DimensionHandle unused; 252 TF_RETURN_IF_ERROR( 253 c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused)); 254 255 ShapeHandle out; 256 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out)); 257 c->set_output(0, out); 258 return Status::OK(); 259 }); 260 261 REGISTER_OP("LogMatrixDeterminant") 262 .Input("input: T") 263 .Output("sign: T") 264 .Output("log_abs_determinant: T") 265 .Attr("T: {half, float, double, complex64, complex128}") 266 .SetShapeFn([](InferenceContext* c) { 267 ShapeHandle input; 268 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); 269 270 DimensionHandle unused; 271 TF_RETURN_IF_ERROR( 272 c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused)); 273 274 ShapeHandle s; 275 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s)); 276 c->set_output(0, s); 277 278 ShapeHandle out; 279 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out)); 280 c->set_output(1, out); 281 return Status::OK(); 282 }); 283 284 
// Ops whose output has the same [..., N, N] shape as their square input 0.
REGISTER_OP("MatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixExponential")
    .Deprecated(
        27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixLogarithm")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// NOTE(review): the shape function only validates input 0 ("l"); the shape of
// "grad" is not checked or merged here.
REGISTER_OP("CholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// Deprecated. For a [..., N, N] input the single output is [..., N+1, N].
REGISTER_OP("SelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));

      DimensionHandle d = c->Dim(input, -1);
      DimensionHandle d_plus_1;
      TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));

      // batch + [N+1, N].
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("SelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("Lu")
    .Input("input: T")
    .Output("lu: T")
    .Output("p: output_idx_type")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("output_idx_type: {int32, int64} = DT_INT32")
    .SetShapeFn(LuShapeFn);

// Square solve: matrix [..., M, M], rhs [..., M, K] -> output [..., M, K].
REGISTER_OP("MatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixSolveShapeFn(c, true /* square */);
    });

// Same shape contract as MatrixSolve (square matrix).
REGISTER_OP("MatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixSolveShapeFn(c, true /* square */);
    });

// Least-squares solve: matrix [..., M, N], rhs [..., M, K] -> [..., N, K].
// l2_regularizer must be a scalar (rank 0).
REGISTER_OP("MatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("fast: bool = True")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle l2_regularizer;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
      return MatrixSolveShapeFn(c, false /* square */);
    });

REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")
    .Output("r: T")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(QrShapeFn);

REGISTER_OP("Svd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SvdShapeFn);

// diagonals is [..., 3, M], rhs is [..., M, K]; output has the shape of rhs.
REGISTER_OP("TridiagonalSolve")
    .Input("diagonals: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalSolveShapeFn);

// Deprecated op registrations. They are kept registered (presumably so that
// older serialized graphs can still be loaded — TODO confirm) and use
// shape_inference::UnknownShape, i.e. no shape checking or inference.

// Can be deleted after 3feb2017.
REGISTER_OP("BatchSelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

// Can all be deleted after 9mar2017.
REGISTER_OP("BatchMatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, double, complex64, complex128}")
    .Deprecated(13, "Use MatrixDeterminant instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixInverse instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use Cholesky instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Deprecated(13, "Use CholeskyGrad instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixTriangularSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("fast: bool = True")
    .Deprecated(13, "Use MatrixSolveLs instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSvd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .Deprecated(13, "Use Svd instead.")
    .SetShapeFn(shape_inference::UnknownShape);

}  // namespace tensorflow