Home | History | Annotate | Download | only in ops
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     16 #include "tensorflow/core/framework/common_shape_fns.h"
     17 #include "tensorflow/core/framework/op.h"
     18 #include "tensorflow/core/framework/shape_inference.h"
     20 namespace tensorflow {
     22 using shape_inference::DimensionHandle;
     23 using shape_inference::InferenceContext;
     24 using shape_inference::ShapeHandle;
     26 namespace {
     28 // Return in <out> the result of making the end of <s> a square matrix.
     29 Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
     30                              ShapeHandle* out) {
     31   ShapeHandle s;
     32   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));
     34   DimensionHandle d;
     35   TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));
     37   ShapeHandle batch_shape;
     38   TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
     39   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
     40   return Status::OK();
     41 }
     43 Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
     44   ShapeHandle out;
     45   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
     46   c->set_output(0, out);
     47   return Status::OK();
     48 }
     50 // The first input is [...,M,N] and second input is either [...,M,K] or [...,M].
     51 // Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M].
     52 Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
     53   ShapeHandle lhs;
     54   ShapeHandle rhs;
     55   if (square) {
     56     TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
     57   } else {
     58     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
     59   }
     60   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
     62   ShapeHandle lhs_batch_shape;
     63   ShapeHandle rhs_batch_shape;
     64   // Make the common batch subshape.
     65   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
     66   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
     67   // Make sure the batch dimensions match between lhs and rhs.
     69       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
     71   DimensionHandle m;
     72   // lhs and rhs have the same value for m to be compatible.
     73   TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
     74   DimensionHandle n = c->Dim(lhs, -1);
     75   if (square) {
     76     TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
     77   }
     79   ShapeHandle out;
     80   // Build final shape (batch_shape + n + k) in <out>.
     81   TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out));
     82   TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
     83   c->set_output(0, out);
     84   return Status::OK();
     85 }
     87 // Input is [...,N,N]. Outputs are:
     88 //   [...,N];[0], if compute_v is false,
     89 //   [...,N];[...,N,N], if compute_v is true.
     90 Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
     91   ShapeHandle input;
     92   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
     93   DimensionHandle n;
     94   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
     95   ShapeHandle batch_shape;
     96   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
     97   ShapeHandle e_shape;
     98   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
     99   c->set_output(0, e_shape);
    100   bool compute_v;
    101   TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
    102   if (compute_v) {
    103     ShapeHandle v_shape;
    104     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    105     c->set_output(1, v_shape);
    106   } else {
    107     c->set_output(1, c->Vector(0ll));
    108   }
    109   return Status::OK();
    110 }
    112 // Input is [...,N,N].
    113 // First and second outputs are:
    114 //   [...,N,N]; [...,N].
    115 Status LuShapeFn(InferenceContext* c) {
    116   ShapeHandle input;
    117   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    119   DimensionHandle n;
    120   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
    122   ShapeHandle batch_shape;
    123   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
    125   ShapeHandle lu_shape;
    126   ShapeHandle p_shape;
    128   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
    129   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));
    131   c->set_output(0, lu_shape);
    132   c->set_output(1, p_shape);
    133   return Status::OK();
    134 }
    136 // Input is [...,M,N].
    137 // First and second outputs are:
    138 //   [...,M,M]; [...,M,N], if full_matrices is true,
    139 //   [...,M,P]; [...,P,N], if full_matrices is false,
    140 // where P = min(M,N).
    141 Status QrShapeFn(InferenceContext* c) {
    142   ShapeHandle input;
    143   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    144   DimensionHandle m = c->Dim(input, -2);
    145   DimensionHandle n = c->Dim(input, -1);
    146   DimensionHandle p;
    147   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
    148   ShapeHandle batch_shape;
    149   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
    150   ShapeHandle q_shape;
    151   ShapeHandle r_shape;
    152   bool full_matrices;
    153   TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
    154   if (full_matrices) {
    155     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
    156     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
    157   } else {
    158     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
    159     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
    160   }
    161   c->set_output(0, q_shape);
    162   c->set_output(1, r_shape);
    163   return Status::OK();
    164 }
    166 // Input is [...,M,N].  First output is [...,min(M,N)].
    167 // Second and third outputs are:
    168 //   [0]; [0], if compute_uv is false.
    169 //   [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
    170 //   [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
    171 // where P = min(M,N).
    172 Status SvdShapeFn(InferenceContext* c) {
    173   ShapeHandle input;
    174   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    175   DimensionHandle m = c->Dim(input, -2);
    176   DimensionHandle n = c->Dim(input, -1);
    177   DimensionHandle p;
    178   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
    179   ShapeHandle batch_shape;
    180   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
    181   ShapeHandle e_shape;
    182   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
    183   c->set_output(0, e_shape);
    184   bool compute_uv;
    185   TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
    186   if (compute_uv) {
    187     ShapeHandle u_shape;
    188     ShapeHandle v_shape;
    189     bool full_matrices;
    190     TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
    191     if (full_matrices) {
    192       TF_RETURN_IF_ERROR(
    193           c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
    194       TF_RETURN_IF_ERROR(
    195           c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    196     } else {
    197       TF_RETURN_IF_ERROR(
    198           c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
    199       TF_RETURN_IF_ERROR(
    200           c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
    201     }
    202     c->set_output(1, u_shape);
    203     c->set_output(2, v_shape);
    204   } else {
    205     c->set_output(1, c->Vector(0ll));
    206     c->set_output(2, c->Vector(0ll));
    207   }
    208   return Status::OK();
    209 }
    211 // The first input is [...,3,M] and second input is [...,M,K].
    212 // Output is [...,M,K].
    213 Status TridiagonalSolveShapeFn(InferenceContext* c) {
    214   ShapeHandle lhs;
    215   ShapeHandle rhs;
    216   // Check that rank is at least 2.
    217   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
    218   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
    220   // Extract batch dimensions and check they are the same.
    221   ShapeHandle lhs_batch_shape;
    222   ShapeHandle rhs_batch_shape;
    223   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
    224   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
    226       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
    228   // Check that "M" is the same in both inputs.
    229   DimensionHandle m_lhs = c->Dim(lhs, -1);
    230   DimensionHandle m_rhs = c->Dim(rhs, -2);
    231   TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));
    233   // Check that next-to-last dimension of the first input is 3.
    234   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs));
    236   // The output shape is the same as rhs shape.
    237   c->set_output(0, rhs);
    238   return Status::OK();
    239 }
    241 }  // namespace
    243 REGISTER_OP("MatrixDeterminant")
    244     .Input("input: T")
    245     .Output("output: T")
    246     .Attr("T: {half, float, double, complex64, complex128}")
    247     .SetShapeFn([](InferenceContext* c) {
    248       ShapeHandle input;
    249       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    251       DimensionHandle unused;
    252       TF_RETURN_IF_ERROR(
    253           c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
    255       ShapeHandle out;
    256       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
    257       c->set_output(0, out);
    258       return Status::OK();
    259     });
    261 REGISTER_OP("LogMatrixDeterminant")
    262     .Input("input: T")
    263     .Output("sign: T")
    264     .Output("log_abs_determinant: T")
    265     .Attr("T: {half, float, double, complex64, complex128}")
    266     .SetShapeFn([](InferenceContext* c) {
    267       ShapeHandle input;
    268       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    270       DimensionHandle unused;
    271       TF_RETURN_IF_ERROR(
    272           c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
    274       ShapeHandle s;
    275       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
    276       c->set_output(0, s);
    278       ShapeHandle out;
    279       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
    280       c->set_output(1, out);
    281       return Status::OK();
    282     });
    284 REGISTER_OP("MatrixInverse")  // One input and one output, both of type T.
    285     .Input("input: T")
    286     .Output("output: T")
    287     .Attr("adjoint: bool = False")  // Boolean attr; defaults to False.
    288     .Attr("T: {double, float, half, complex64, complex128}")
    289     .SetShapeFn(BatchUnchangedSquareShapeFn);  // Input must be [...,N,N]; output has that shape.
    291 REGISTER_OP("MatrixExponential")  // Deprecated op; see the message below.
    292     .Deprecated(
    293         27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    294     .Input("input: T")
    295     .Output("output: T")
    296     .Attr("T: {double, float, half, complex64, complex128}")
    297     .SetShapeFn(BatchUnchangedSquareShapeFn);  // Input must be [...,N,N]; output has that shape.
    299 REGISTER_OP("MatrixLogarithm")  // One input and one output, both of type T.
    300     .Input("input: T")
    301     .Output("output: T")
    302     .Attr("T: {complex64, complex128}")  // Only complex types are allowed here.
    303     .SetShapeFn(BatchUnchangedSquareShapeFn);  // Input must be [...,N,N]; output has that shape.
    305 REGISTER_OP("Cholesky")  // One input and one output, both of type T.
    306     .Input("input: T")
    307     .Output("output: T")
    308     .Attr("T: {double, float, half, complex64, complex128}")
    309     .SetShapeFn(BatchUnchangedSquareShapeFn);  // Input must be [...,N,N]; output has that shape.
    311 REGISTER_OP("CholeskyGrad")  // Two inputs (l, grad) and one output, all of type T.
    312     .Input("l: T")
    313     .Input("grad: T")  // NOTE: the shape function below never inspects this input.
    314     .Output("output: T")
    315     .Attr("T: {half, float, double}")
    316     .SetShapeFn(BatchUnchangedSquareShapeFn);  // Only input 0 (l) is checked as [...,N,N].
    318 REGISTER_OP("SelfAdjointEig")
    319     .Input("input: T")
    320     .Output("output: T")
    321     .Attr("T: {double, float, half}")
    322     .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    323     .SetShapeFn([](InferenceContext* c) {
    324       ShapeHandle input;
    325       TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
    327       DimensionHandle d = c->Dim(input, -1);
    328       DimensionHandle d_plus_1;
    329       TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));
    331       ShapeHandle s;
    332       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
    333       TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
    334       c->set_output(0, s);
    335       return Status::OK();
    336     });
    338 REGISTER_OP("SelfAdjointEigV2")  // One input, two outputs (e, v), all of type T.
    339     .Input("input: T")
    340     .Output("e: T")
    341     .Output("v: T")
    342     .Attr("compute_v: bool = True")  // Read by the shape function to size output v.
    343     .Attr("T: {double, float, half, complex64, complex128}")
    344     .SetShapeFn(SelfAdjointEigV2ShapeFn);  // e: [...,N]; v: [...,N,N] or [0].
    346 REGISTER_OP("Lu")  // One input of type T; outputs lu (type T) and p (index type).
    347     .Input("input: T")
    348     .Output("lu: T")
    349     .Output("p: output_idx_type")
    350     .Attr("T: {double, float, half, complex64, complex128}")
    351     .Attr("output_idx_type: {int32, int64} = DT_INT32")  // Type of output p; defaults to DT_INT32.
    352     .SetShapeFn(LuShapeFn);  // Input [...,N,N]; lu: [...,N,N]; p: [...,N].
    354 REGISTER_OP("MatrixSolve")
    355     .Input("matrix: T")
    356     .Input("rhs: T")
    357     .Output("output: T")
    358     .Attr("adjoint: bool = False")
    359     .Attr("T: {double, float, half, complex64, complex128}")
    360     .SetShapeFn([](InferenceContext* c) {
    361       return MatrixSolveShapeFn(c, true /* square (*/);
    362     });
    364 REGISTER_OP("MatrixTriangularSolve")
    365     .Input("matrix: T")
    366     .Input("rhs: T")
    367     .Output("output: T")
    368     .Attr("lower: bool = True")
    369     .Attr("adjoint: bool = False")
    370     .Attr("T: {double, float, half, complex64, complex128}")
    371     .SetShapeFn([](InferenceContext* c) {
    372       return MatrixSolveShapeFn(c, true /* square (*/);
    373     });
    375 REGISTER_OP("MatrixSolveLs")
    376     .Input("matrix: T")
    377     .Input("rhs: T")
    378     .Input("l2_regularizer: double")
    379     .Output("output: T")
    380     .Attr("T: {double, float, half, complex64, complex128}")
    381     .Attr("fast: bool = True")
    382     .SetShapeFn([](InferenceContext* c) {
    383       ShapeHandle l2_regularizer;
    384       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
    385       return MatrixSolveShapeFn(c, false /* square */);
    386     });
    388 REGISTER_OP("MatrixSquareRoot")  // One input and one output, both of type T.
    389     .Input("input: T")
    390     .Output("output: T")
    391     .Attr("T: {double, float, half, complex64, complex128}")
    392     .SetShapeFn(BatchUnchangedSquareShapeFn);  // Input must be [...,N,N]; output has that shape.
    394 REGISTER_OP("Qr")  // One input, two outputs (q, r), all of type T.
    395     .Input("input: T")
    396     .Output("q: T")
    397     .Output("r: T")
    398     .Attr("full_matrices: bool = False")  // Read by QrShapeFn to choose the output shapes.
    399     .Attr("T: {double, float, half, complex64, complex128}")
    400     .SetShapeFn(QrShapeFn);  // Input [...,M,N]; see QrShapeFn for the q/r shapes.
    402 REGISTER_OP("Svd")  // One input, three outputs (s, u, v), all of type T.
    403     .Input("input: T")
    404     .Output("s: T")
    405     .Output("u: T")
    406     .Output("v: T")
    407     .Attr("compute_uv: bool = True")  // When false, SvdShapeFn sets u and v to [0].
    408     .Attr("full_matrices: bool = False")  // Read by SvdShapeFn only when compute_uv is true.
    409     .Attr("T: {double, float, half, complex64, complex128}")
    410     .SetShapeFn(SvdShapeFn);  // Input [...,M,N]; s is [...,min(M,N)].
    412 REGISTER_OP("TridiagonalSolve")  // Two inputs and one output, all of type T.
    413     .Input("diagonals: T")  // Shape function requires [...,3,M].
    414     .Input("rhs: T")  // Shape function requires [...,M,K].
    415     .Output("output: T")
    416     .Attr("T: {double, float, complex64, complex128}")
    417     .SetShapeFn(TridiagonalSolveShapeFn);  // Output takes the shape of rhs.
    419 // Deprecated op registrations:
    421 // Can be deleted after 3feb2017.
    422 REGISTER_OP("BatchSelfAdjointEig")  // Deprecated registration kept alongside the newer op.
    423     .Input("input: T")
    424     .Output("output: T")
    425     .Attr("T: {double, float}")
    426     .Deprecated(11, "Use SelfAdjointEigV2 instead.")  // Message names the replacement op.
    427     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    429 // Can all be deleted after 9mar2017.
    430 REGISTER_OP("BatchMatrixDeterminant")  // Deprecated registration kept alongside the newer op.
    431     .Input("input: T")
    432     .Output("output: T")
    433     .Attr("T: {float, double, complex64, complex128}")
    434     .Deprecated(13, "Use MatrixDeterminant instead.")  // Message names the replacement op.
    435     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    437 REGISTER_OP("BatchMatrixInverse")  // Deprecated registration kept alongside the newer op.
    438     .Input("input: T")
    439     .Output("output: T")
    440     .Attr("adjoint: bool = False")
    441     .Attr("T: {double, float}")
    442     .Deprecated(13, "Use MatrixInverse instead.")  // Message names the replacement op.
    443     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    445 REGISTER_OP("BatchCholesky")  // Deprecated registration kept alongside the newer op.
    446     .Input("input: T")
    447     .Output("output: T")
    448     .Attr("T: {double, float}")
    449     .Deprecated(13, "Use Cholesky instead.")  // Message names the replacement op.
    450     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    452 REGISTER_OP("BatchCholeskyGrad")  // Deprecated registration kept alongside the newer op.
    453     .Input("l: T")
    454     .Input("grad: T")
    455     .Output("output: T")
    456     .Attr("T: {float, double}")
    457     .Deprecated(13, "Use CholeskyGrad instead.")  // Message names the replacement op.
    458     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    460 REGISTER_OP("BatchSelfAdjointEigV2")  // Deprecated registration kept alongside the newer op.
    461     .Input("input: T")
    462     .Output("e: T")
    463     .Output("v: T")
    464     .Attr("compute_v: bool = True")
    465     .Attr("T: {double, float}")
    466     .Deprecated(13, "Use SelfAdjointEigV2 instead.")  // Message names the replacement op.
    467     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    469 REGISTER_OP("BatchMatrixSolve")  // Deprecated registration kept alongside the newer op.
    470     .Input("matrix: T")
    471     .Input("rhs: T")
    472     .Output("output: T")
    473     .Attr("adjoint: bool = False")
    474     .Attr("T: {double, float}")
    475     .Deprecated(13, "Use MatrixSolve instead.")  // Message names the replacement op.
    476     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    478 REGISTER_OP("BatchMatrixTriangularSolve")  // Deprecated registration kept alongside the newer op.
    479     .Input("matrix: T")
    480     .Input("rhs: T")
    481     .Output("output: T")
    482     .Attr("lower: bool = True")
    483     .Attr("adjoint: bool = False")
    484     .Attr("T: {double, float}")
    485     .Deprecated(13, "Use MatrixTriangularSolve instead.")  // Message names the replacement op.
    486     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    488 REGISTER_OP("BatchMatrixSolveLs")  // Deprecated registration kept alongside the newer op.
    489     .Input("matrix: T")
    490     .Input("rhs: T")
    491     .Input("l2_regularizer: double")  // Fixed to double regardless of T.
    492     .Output("output: T")
    493     .Attr("T: {double, float}")
    494     .Attr("fast: bool = True")
    495     .Deprecated(13, "Use MatrixSolveLs instead.")  // Message names the replacement op.
    496     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    498 REGISTER_OP("BatchSvd")  // Deprecated registration kept alongside the newer op.
    499     .Input("input: T")
    500     .Output("s: T")
    501     .Output("u: T")
    502     .Output("v: T")
    503     .Attr("compute_uv: bool = True")
    504     .Attr("full_matrices: bool = False")
    505     .Attr("T: {double, float, complex64, complex128}")
    506     .Deprecated(13, "Use Svd instead.")  // Message names the replacement op.
    507     .SetShapeFn(shape_inference::UnknownShape);  // No shape checks are defined in this file for this op.
    509 }  // namespace tensorflow