Home | History | Annotate | Download | only in ops
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/common_shape_fns.h"
     17 #include "tensorflow/core/framework/op.h"
     18 #include "tensorflow/core/framework/shape_inference.h"
     19 
     20 namespace tensorflow {
     21 
     22 using shape_inference::DimensionHandle;
     23 using shape_inference::InferenceContext;
     24 using shape_inference::ShapeHandle;
     25 
     26 namespace {
     27 
     28 // Return in <out> the result of making the end of <s> a square matrix.
     29 Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
     30                              ShapeHandle* out) {
     31   ShapeHandle s;
     32   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));
     33 
     34   DimensionHandle d;
     35   TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));
     36 
     37   ShapeHandle batch_shape;
     38   TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
     39   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
     40   return Status::OK();
     41 }
     42 
     43 Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
     44   ShapeHandle out;
     45   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
     46   c->set_output(0, out);
     47   return Status::OK();
     48 }
     49 
     50 // The first input is [...,M,N] and second input is either [...,M,K] or [...,M].
     51 // Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M].
     52 Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
     53   ShapeHandle lhs;
     54   ShapeHandle rhs;
     55   if (square) {
     56     TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
     57   } else {
     58     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
     59   }
     60   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
     61 
     62   ShapeHandle lhs_batch_shape;
     63   ShapeHandle rhs_batch_shape;
     64   // Make the common batch subshape.
     65   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
     66   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
     67   // Make sure the batch dimensions match between lhs and rhs.
     68   TF_RETURN_IF_ERROR(
     69       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
     70 
     71   DimensionHandle m;
     72   // lhs and rhs have the same value for m to be compatible.
     73   TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
     74   DimensionHandle n = c->Dim(lhs, -1);
     75   if (square) {
     76     TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
     77   }
     78 
     79   ShapeHandle out;
     80   // Build final shape (batch_shape + n + k) in <out>.
     81   TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out));
     82   TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
     83   c->set_output(0, out);
     84   return Status::OK();
     85 }
     86 
     87 // Input is [...,N,N]. Outputs are:
     88 //   [...,N];[0], if compute_v is false,
     89 //   [...,N];[...,N,N], if compute_v is true.
     90 Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
     91   ShapeHandle input;
     92   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
     93   DimensionHandle n;
     94   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
     95   ShapeHandle batch_shape;
     96   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
     97   ShapeHandle e_shape;
     98   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
     99   c->set_output(0, e_shape);
    100   bool compute_v;
    101   TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
    102   if (compute_v) {
    103     ShapeHandle v_shape;
    104     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    105     c->set_output(1, v_shape);
    106   } else {
    107     c->set_output(1, c->Vector(0ll));
    108   }
    109   return Status::OK();
    110 }
    111 
    112 // Input is [...,N,N].
    113 // First and second outputs are:
    114 //   [...,N,N]; [...,N].
    115 Status LuShapeFn(InferenceContext* c) {
    116   ShapeHandle input;
    117   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    118 
    119   DimensionHandle n;
    120   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
    121 
    122   ShapeHandle batch_shape;
    123   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
    124 
    125   ShapeHandle lu_shape;
    126   ShapeHandle p_shape;
    127 
    128   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
    129   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));
    130 
    131   c->set_output(0, lu_shape);
    132   c->set_output(1, p_shape);
    133   return Status::OK();
    134 }
    135 
    136 // Input is [...,M,N].
    137 // First and second outputs are:
    138 //   [...,M,M]; [...,M,N], if full_matrices is true,
    139 //   [...,M,P]; [...,P,N], if full_matrices is false,
    140 // where P = min(M,N).
    141 Status QrShapeFn(InferenceContext* c) {
    142   ShapeHandle input;
    143   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    144   DimensionHandle m = c->Dim(input, -2);
    145   DimensionHandle n = c->Dim(input, -1);
    146   DimensionHandle p;
    147   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
    148   ShapeHandle batch_shape;
    149   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
    150   ShapeHandle q_shape;
    151   ShapeHandle r_shape;
    152   bool full_matrices;
    153   TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
    154   if (full_matrices) {
    155     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
    156     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
    157   } else {
    158     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
    159     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
    160   }
    161   c->set_output(0, q_shape);
    162   c->set_output(1, r_shape);
    163   return Status::OK();
    164 }
    165 
    166 // Input is [...,M,N].  First output is [...,min(M,N)].
    167 // Second and third outputs are:
    168 //   [0]; [0], if compute_uv is false.
    169 //   [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
    170 //   [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
    171 // where P = min(M,N).
    172 Status SvdShapeFn(InferenceContext* c) {
    173   ShapeHandle input;
    174   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    175   DimensionHandle m = c->Dim(input, -2);
    176   DimensionHandle n = c->Dim(input, -1);
    177   DimensionHandle p;
    178   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
    179   ShapeHandle batch_shape;
    180   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
    181   ShapeHandle e_shape;
    182   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
    183   c->set_output(0, e_shape);
    184   bool compute_uv;
    185   TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
    186   if (compute_uv) {
    187     ShapeHandle u_shape;
    188     ShapeHandle v_shape;
    189     bool full_matrices;
    190     TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
    191     if (full_matrices) {
    192       TF_RETURN_IF_ERROR(
    193           c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
    194       TF_RETURN_IF_ERROR(
    195           c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    196     } else {
    197       TF_RETURN_IF_ERROR(
    198           c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
    199       TF_RETURN_IF_ERROR(
    200           c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
    201     }
    202     c->set_output(1, u_shape);
    203     c->set_output(2, v_shape);
    204   } else {
    205     c->set_output(1, c->Vector(0ll));
    206     c->set_output(2, c->Vector(0ll));
    207   }
    208   return Status::OK();
    209 }
    210 
    211 // The first input is [...,3,M] and second input is [...,M,K].
    212 // Output is [...,M,K].
    213 Status TridiagonalSolveShapeFn(InferenceContext* c) {
    214   ShapeHandle lhs;
    215   ShapeHandle rhs;
    216   // Check that rank is at least 2.
    217   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
    218   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
    219 
    220   // Extract batch dimensions and check they are the same.
    221   ShapeHandle lhs_batch_shape;
    222   ShapeHandle rhs_batch_shape;
    223   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
    224   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
    225   TF_RETURN_IF_ERROR(
    226       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
    227 
    228   // Check that "M" is the same in both inputs.
    229   DimensionHandle m_lhs = c->Dim(lhs, -1);
    230   DimensionHandle m_rhs = c->Dim(rhs, -2);
    231   TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));
    232 
    233   // Check that next-to-last dimension of the first input is 3.
    234   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs));
    235 
    236   // The output shape is the same as rhs shape.
    237   c->set_output(0, rhs);
    238   return Status::OK();
    239 }
    240 
    241 }  // namespace
    242 
// MatrixDeterminant: batch of square matrices [..., N, N] -> scalars [...].
REGISTER_OP("MatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

      // The two innermost dimensions must agree (square matrix).
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));

      // The output drops the matrix dimensions, keeping only the batch dims.
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
      c->set_output(0, out);
      return Status::OK();
    });
    260 
    261 REGISTER_OP("LogMatrixDeterminant")
    262     .Input("input: T")
    263     .Output("sign: T")
    264     .Output("log_abs_determinant: T")
    265     .Attr("T: {half, float, double, complex64, complex128}")
    266     .SetShapeFn([](InferenceContext* c) {
    267       ShapeHandle input;
    268       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
    269 
    270       DimensionHandle unused;
    271       TF_RETURN_IF_ERROR(
    272           c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
    273 
    274       ShapeHandle s;
    275       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
    276       c->set_output(0, s);
    277 
    278       ShapeHandle out;
    279       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
    280       c->set_output(1, out);
    281       return Status::OK();
    282     });
    283 
// MatrixInverse: [..., N, N] -> [..., N, N]. The "adjoint" attr does not
// affect the output shape.
REGISTER_OP("MatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// MatrixExponential: deprecated at GraphDef version 27 in favor of the
// Python implementation; shape contract is still [..., N, N] -> [..., N, N].
REGISTER_OP("MatrixExponential")
    .Deprecated(
        27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// MatrixLogarithm: complex types only; [..., N, N] -> [..., N, N].
REGISTER_OP("MatrixLogarithm")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// Cholesky: [..., N, N] -> [..., N, N].
REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// CholeskyGrad: output matches the square shape of input 0 ("l").
// NOTE(review): the shape fn constrains only input 0; "grad" is not checked
// here — presumably validated at kernel time. Confirm before tightening.
REGISTER_OP("CholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);
    317 
// SelfAdjointEig (V1, deprecated): input [..., N, N]; output [..., N+1, N].
// NOTE(review): the extra row suggests eigenvalues are packed together with
// the eigenvectors in a single output tensor — confirm against the kernel.
REGISTER_OP("SelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));

      // Output matrix has one extra row: [d + 1, d].
      DimensionHandle d = c->Dim(input, -1);
      DimensionHandle d_plus_1;
      TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));

      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
      c->set_output(0, s);
      return Status::OK();
    });
    337 
// SelfAdjointEigV2: eigendecomposition; shapes per SelfAdjointEigV2ShapeFn
// (e: [..., N]; v: [..., N, N] when compute_v, else a length-0 vector).
REGISTER_OP("SelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

// Lu: LU factorization; lu is [..., N, N], p is [..., N] in the requested
// integer index type.
REGISTER_OP("Lu")
    .Input("input: T")
    .Output("lu: T")
    .Output("p: output_idx_type")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("output_idx_type: {int32, int64} = DT_INT32")
    .SetShapeFn(LuShapeFn);
    353 
    354 REGISTER_OP("MatrixSolve")
    355     .Input("matrix: T")
    356     .Input("rhs: T")
    357     .Output("output: T")
    358     .Attr("adjoint: bool = False")
    359     .Attr("T: {double, float, half, complex64, complex128}")
    360     .SetShapeFn([](InferenceContext* c) {
    361       return MatrixSolveShapeFn(c, true /* square (*/);
    362     });
    363 
    364 REGISTER_OP("MatrixTriangularSolve")
    365     .Input("matrix: T")
    366     .Input("rhs: T")
    367     .Output("output: T")
    368     .Attr("lower: bool = True")
    369     .Attr("adjoint: bool = False")
    370     .Attr("T: {double, float, half, complex64, complex128}")
    371     .SetShapeFn([](InferenceContext* c) {
    372       return MatrixSolveShapeFn(c, true /* square (*/);
    373     });
    374 
// MatrixSolveLs: least-squares solve. matrix may be rectangular [..., M, N],
// rhs is [..., M, K], output is [..., N, K]. l2_regularizer must be a scalar.
REGISTER_OP("MatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("fast: bool = True")
    .SetShapeFn([](InferenceContext* c) {
      // Validate the scalar regularizer before the generic solve-shape check.
      ShapeHandle l2_regularizer;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
      return MatrixSolveShapeFn(c, false /* square */);
    });
    387 
// MatrixSquareRoot: [..., N, N] -> [..., N, N].
REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// Qr: QR factorization; shapes per QrShapeFn (full vs. reduced form).
REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")
    .Output("r: T")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(QrShapeFn);

// Svd: singular value decomposition; shapes per SvdShapeFn.
REGISTER_OP("Svd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SvdShapeFn);

// TridiagonalSolve: diagonals [..., 3, M], rhs [..., M, K] -> [..., M, K].
// Note: unlike most ops in this file, half is not supported here.
REGISTER_OP("TridiagonalSolve")
    .Input("diagonals: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalSolveShapeFn);
    418 
// Deprecated op registrations:
// These are the pre-batch-support "Batch*" variants, superseded by the ops
// above (which handle batching natively). They are kept only so old GraphDefs
// still load; all use UnknownShape since no new graphs should create them.

// Can be deleted after 3feb2017.
REGISTER_OP("BatchSelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

// Can all be deleted after 9mar2017.
REGISTER_OP("BatchMatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, double, complex64, complex128}")
    .Deprecated(13, "Use MatrixDeterminant instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixInverse instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use Cholesky instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Deprecated(13, "Use CholeskyGrad instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixTriangularSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("fast: bool = True")
    .Deprecated(13, "Use MatrixSolveLs instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSvd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .Deprecated(13, "Use Svd instead.")
    .SetShapeFn(shape_inference::UnknownShape);
    508 
    509 }  // namespace tensorflow
    510