/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace {

namespace f = test::function;
using FDH = FunctionDefHelper;

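// Returns a new session configured to run on a single CPU device.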
std::unique_ptr<Session> NewSession() {
  SessionOptions opts;
  (*opts.config.mutable_device_count())["CPU"] = 1;
  return std::unique_ptr<Session>(NewSession(opts));
}

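// Builds a graph that takes the symbolic gradient of a two-input Pack along
// `axis` and returns the gradients with respect to x0 and x1.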
std::vector<Tensor> PackGrad(const Tensor& x0, const Tensor& x1,
                             const Tensor& dy, int axis) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x0", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("x1", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("axis", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
               {{"f", FDH::FunctionRef("Pack",
                                       {{"N", 2}, {"T", T}, {"axis", axis}})},
                {"Tin", DataTypeSlice{T, T, T}},
                {"Tout", DataTypeSlice{T, T}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x0:0", x0},
                         {"x1:0", x1},
                         {"axis:0", test::AsScalar(axis)},
                         {"dy:0", dy}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, PackGrad) {
  Tensor x0(DT_FLOAT, {2, 3});
  x0.flat<float>().setZero();
  Tensor x1(DT_FLOAT, {2, 3});
  x1.flat<float>().setZero();
  Tensor dy(DT_FLOAT, {2, 2, 3});
  test::FillIota<float>(&dy, 0);
  auto dx = PackGrad(x0, x1, dy, 0);
  test::ExpectClose(dx[0],
                    test::AsTensor<float>({0., 1., 2., 3., 4., 5.}, {2, 3}));
  test::ExpectClose(dx[1],
                    test::AsTensor<float>({6., 7., 8., 9., 10., 11.}, {2, 3}));
}

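// Builds a graph that takes the symbolic gradient of a two-output Unpack
// along `axis` and returns the gradient with respect to the input x.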
std::vector<Tensor> UnpackGrad(const Tensor& x, const Tensor& dy0,
                               const Tensor& dy1, int axis) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("axis", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy0", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dy1", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x", "dy0", "dy1"},
               {{"f", FDH::FunctionRef("Unpack",
                                       {{"num", 2}, {"T", T}, {"axis", axis}})},
                {"Tin", DataTypeSlice{T, T, T}},
                {"Tout", DataTypeSlice{T}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x},
                         {"axis:0", test::AsScalar(axis)},
                         {"dy0:0", dy0},
                         {"dy1:0", dy1}},
                        {"dx:0"}, {}, &out));
  CHECK_EQ(out.size(), 1);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, UnpackGrad) {
  Tensor x(DT_FLOAT, {2, 2, 3});
  x.flat<float>().setZero();
  Tensor dy0(DT_FLOAT, {2, 3});
  Tensor dy1(DT_FLOAT, {2, 3});
  test::FillIota<float>(&dy0, 0);
  test::FillIota<float>(&dy1, 100);
  auto dx = UnpackGrad(x, dy0, dy1, 0);
  test::ExpectClose(dx[0], test::AsTensor<float>({0., 1., 2., 3., 4., 5., 100.,
                                                  101., 102., 103., 104., 105.},
                                                 {2, 2, 3}));
}

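// Builds a graph that takes the symbolic gradient of Concat(dim, x0, x1) and
// returns the gradients with respect to dim, x0, and x1.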
std::vector<Tensor> ConcatGrad(int dim, const Tensor& x0, const Tensor& x1,
                               const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("x0", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("x1", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"dim", "x0", "x1", "dy"},
               {{"f", FDH::FunctionRef("Concat", {{"N", 2}, {"T", T}})},
                {"Tin", DataTypeSlice{DT_INT32, T, T, T}},
                {"Tout", DataTypeSlice{DT_INT32, T, T}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run(
      {{"dim", test::AsScalar(dim)}, {"x0:0", x0}, {"x1:0", x1}, {"dy:0", dy}},
      {"dx:0", "dx:1", "dx:2"}, {}, &out));
  CHECK_EQ(out.size(), 3);
  TF_CHECK_OK(sess->Close());
  return out;
}

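// Same as ConcatGrad, but for ConcatV2, which takes the concat axis as its
// last input instead of its first.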
std::vector<Tensor> ConcatGradV2(int dim, const Tensor& x0, const Tensor& x1,
                                 const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x0", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("x1", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x0", "x1", "dim", "dy"},
               {{"f", FDH::FunctionRef("ConcatV2", {{"N", 2}, {"T", T}})},
                {"Tin", DataTypeSlice{T, T, DT_INT32, T}},
                {"Tout", DataTypeSlice{T, T, DT_INT32}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run(
      {{"x0:0", x0}, {"x1:0", x1}, {"dim", test::AsScalar(dim)}, {"dy:0", dy}},
      {"dx:0", "dx:1", "dx:2"}, {}, &out));
  CHECK_EQ(out.size(), 3);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, ConcatGrad) {
  Tensor x0(DT_FLOAT, {2, 3, 5});
  x0.flat<float>().setZero();
  Tensor x1(DT_FLOAT, {2, 1, 5});
  x1.flat<float>().setZero();
  Tensor dy(DT_FLOAT, {2, 4, 5});
  test::FillIota<float>(&dy, 0);

  // Test Concat.
  auto dx = ConcatGrad(1, x0, x1, dy);
  test::ExpectTensorEqual<int32>(dx[0], test::AsScalar(0));
  test::ExpectClose(
      dx[1],
      test::AsTensor<float>({0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
                             10., 11., 12., 13., 14., 20., 21., 22., 23., 24.,
                             25., 26., 27., 28., 29., 30., 31., 32., 33., 34.},
                            {2, 3, 5}));
  test::ExpectClose(dx[2], test::AsTensor<float>({15., 16., 17., 18., 19., 35.,
                                                  36., 37., 38., 39.},
                                                 {2, 1, 5}));

  // Test ConcatV2 with positive concat axis.
  dx = ConcatGradV2(1, x0, x1, dy);
  test::ExpectTensorEqual<int32>(dx[dx.size() - 1], test::AsScalar(0));
  test::ExpectClose(
      dx[0],
      test::AsTensor<float>({0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
                             10., 11., 12., 13., 14., 20., 21., 22., 23., 24.,
                             25., 26., 27., 28., 29., 30., 31., 32., 33., 34.},
                            {2, 3, 5}));
  test::ExpectClose(dx[1], test::AsTensor<float>({15., 16., 17., 18., 19., 35.,
                                                  36., 37., 38., 39.},
                                                 {2, 1, 5}));

  // Test ConcatV2 with negative concat axis.
  dx = ConcatGradV2(-2, x0, x1, dy);
  test::ExpectTensorEqual<int32>(dx[dx.size() - 1], test::AsScalar(0));
  test::ExpectClose(
      dx[0],
      test::AsTensor<float>({0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
                             10., 11., 12., 13., 14., 20., 21., 22., 23., 24.,
                             25., 26., 27., 28., 29., 30., 31., 32., 33., 34.},
                            {2, 3, 5}));
  test::ExpectClose(dx[1], test::AsTensor<float>({15., 16., 17., 18., 19., 35.,
                                                  36., 37., 38., 39.},
                                                 {2, 1, 5}));
}

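// Builds a graph that takes the symbolic gradient of a two-way Split along
// `dim` and returns the gradients with respect to dim and x.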
std::vector<Tensor> SplitGrad(int dim, const Tensor& x, const Tensor& dy0,
                              const Tensor& dy1) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("dim", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dy0", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dy1", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"dim", "x", "dy0", "dy1"},
               {{"f", FDH::FunctionRef(
                          "Split",
                          {{"split_dim", dim}, {"num_split", 2}, {"T", T}})},
                {"Tin", DataTypeSlice{DT_INT32, T, T, T}},
                {"Tout", DataTypeSlice{DT_INT32, T}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"dim", test::AsScalar(dim)},
                         {"x:0", x},
                         {"dy0:0", dy0},
                         {"dy1:0", dy1}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, SplitGrad) {
  Tensor x(DT_FLOAT, {2, 4, 5});
  x.flat<float>().setZero();
  Tensor dy0(DT_FLOAT, {2, 2, 5});
  Tensor dy1(DT_FLOAT, {2, 2, 5});
  test::FillIota<float>(&dy0, 0);
  test::FillIota<float>(&dy1, 100);
  auto dx = SplitGrad(1, x, dy0, dy1);
  test::ExpectTensorEqual<int32>(dx[0], test::AsScalar(0));
  test::ExpectClose(
      dx[1], test::AsTensor<float>(
                 {0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,
                  100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
                  10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,
                  110., 111., 112., 113., 114., 115., 116., 117., 118., 119.},
                 {2, 4, 5}));
}

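// Builds a graph that takes the symbolic gradient of Reshape(x, s) and
// returns the gradients with respect to x and the shape s.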
std::vector<Tensor> ReshapeGrad(const Tensor& x, const Tensor& s,
                                const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x", "s", "dy"},
               {{"f", FDH::FunctionRef("Reshape", {{"T", T}})},
                {"Tin", DataTypeSlice{T, DT_INT32, T}},
                {"Tout", DataTypeSlice{T, DT_INT32}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"s:0", s}, {"dy:0", dy}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, ReshapeGrad) {
  Tensor x(DT_FLOAT, {2, 4, 5});
  x.flat<float>().setZero();
  auto s = test::AsTensor<int32>({8, 5});
  Tensor dy(DT_FLOAT, {8, 5});
  test::FillIota<float>(&dy, 73);
  auto dx = ReshapeGrad(x, s, dy);
  test::ExpectClose(
      dx[0], test::AsTensor<float>(
                 {73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,
                  83.,  84.,  85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,
                  93.,  94.,  95.,  96.,  97.,  98.,  99.,  100., 101., 102.,
                  103., 104., 105., 106., 107., 108., 109., 110., 111., 112.},
                 {2, 4, 5}));
  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0}));
}

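// Builds a graph that takes the symbolic gradient of ExpandDims(x, s) and
// returns the gradients with respect to x and the inserted-dimension index s.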
std::vector<Tensor> ExpandDimsGrad(const Tensor& x, const Tensor& s,
                                   const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x", "s", "dy"},
               {{"f", FDH::FunctionRef("ExpandDims", {{"T", T}})},
                {"Tin", DataTypeSlice{T, DT_INT32, T}},
                {"Tout", DataTypeSlice{T, DT_INT32}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"s:0", s}, {"dy:0", dy}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, ExpandDimsGrad) {
  Tensor x(DT_FLOAT, {2, 4, 5});
  x.flat<float>().setZero();
  auto s = test::AsTensor<int32>({1});
  Tensor dy(DT_FLOAT, {2, 1, 4, 5});
  test::FillIota<float>(&dy, 73);
  auto dx = ExpandDimsGrad(x, s, dy);
  test::ExpectClose(
      dx[0], test::AsTensor<float>(
                 {73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.,  81.,  82.,
                  83.,  84.,  85.,  86.,  87.,  88.,  89.,  90.,  91.,  92.,
                  93.,  94.,  95.,  96.,  97.,  98.,  99.,  100., 101., 102.,
                  103., 104., 105., 106., 107., 108., 109., 110., 111., 112.},
                 {2, 4, 5}));
  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0}));
}

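// Builds a graph that takes the symbolic gradient of Squeeze(x) and returns
// the gradient with respect to x.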
std::vector<Tensor> SqueezeGrad(const Tensor& x, const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x", "dy"},
               {{"f", FDH::FunctionRef("Squeeze", {{"T", T}})},
                {"Tin", DataTypeSlice{T, T}},
                {"Tout", DataTypeSlice{T}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"dy:0", dy}}, {"dx:0"}, {}, &out));
  CHECK_EQ(out.size(), 1);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, SqueezeGrad) {
  Tensor x(DT_FLOAT, {2, 1, 3});
  x.flat<float>().setZero();
  Tensor dy(DT_FLOAT, {2, 3});
  test::FillIota<float>(&dy, 1);
  auto dx = SqueezeGrad(x, dy);
  test::ExpectClose(dx[0],
                    test::AsTensor<float>({1., 2., 3., 4., 5., 6.}, {2, 1, 3}));
}

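// Builds a graph that takes the symbolic gradient of Transpose(x, p) and
// returns the gradients with respect to x and the permutation p.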
std::vector<Tensor> TransposeGrad(const Tensor& x, const Tensor& p,
                                  const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("p", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x", "p", "dy"},
               {{"f", FDH::FunctionRef("Transpose", {{"T", T}})},
                {"Tin", DataTypeSlice{T, DT_INT32, T}},
                {"Tout", DataTypeSlice{T, DT_INT32}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"p:0", p}, {"dy:0", dy}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, TransposeGrad) {
  Tensor x(DT_FLOAT, {2, 4, 5});
  x.flat<float>().setZero();
  auto p = test::AsTensor<int32>({2, 0, 1});
  Tensor dy(DT_FLOAT, {5, 2, 4});
  test::FillIota<float>(&dy, 0);
  auto dx = TransposeGrad(x, p, dy);
  test::ExpectClose(dx[0], test::AsTensor<float>(
                               {0., 8.,  16., 24., 32., 1., 9.,  17., 25., 33.,
                                2., 10., 18., 26., 34., 3., 11., 19., 27., 35.,
                                4., 12., 20., 28., 36., 5., 13., 21., 29., 37.,
                                6., 14., 22., 30., 38., 7., 15., 23., 31., 39.},
                               {2, 4, 5}));
  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
}

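// Builds a graph that takes the symbolic gradient of Reverse(x, dims), where
// `dims` is a bool mask of dimensions to reverse, and returns the gradients
// with respect to x and dims.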
std::vector<Tensor> ReverseGrad(const Tensor& x, const Tensor& dims,
                                const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dims", "Placeholder", {}, {{"dtype", DT_BOOL}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("dx", "SymbolicGradient", {"x", "dims", "dy"},
               {{"f", FDH::FunctionRef("Reverse", {{"T", T}})},
                {"Tin", DataTypeSlice{T, DT_BOOL, T}},
                {"Tout", DataTypeSlice{T, DT_BOOL}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"dims:0", dims}, {"dy:0", dy}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, ReverseGrad) {
  Tensor x(DT_FLOAT, {2, 3});
  x.flat<float>().setZero();
  auto dims = test::AsTensor<bool>({false, true});
  Tensor dy(DT_FLOAT, {2, 3});
  test::FillIota<float>(&dy, 1);
  auto dx = ReverseGrad(x, dims, dy);
  test::ExpectClose(dx[0],
                    test::AsTensor<float>({3., 2., 1., 6., 5., 4.}, {2, 3}));
  test::ExpectTensorEqual<bool>(dx[1], test::AsTensor<bool>({false, false}));
}

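// Builds a graph that takes the symbolic gradient of ReverseV2(x, axis),
// which uses integer axis indices instead of a bool mask, and returns the
// gradients with respect to x and axis.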
std::vector<Tensor> ReverseV2Grad(const Tensor& x, const Tensor& axis,
                                  const Tensor& dy) {
  auto T = DT_FLOAT;
  auto Tidx = DT_INT32;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("axis", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef(
           "dx", "SymbolicGradient", {"x", "axis", "dy"},
           {{"f", FDH::FunctionRef("ReverseV2", {{"T", T}, {"Tidx", Tidx}})},
            {"Tin", DataTypeSlice{T, DT_INT32, T}},
            {"Tout", DataTypeSlice{T, DT_INT32}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"axis:0", axis}, {"dy:0", dy}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, ReverseV2Grad) {
  Tensor x(DT_FLOAT, {2, 3});
  x.flat<float>().setZero();
  auto axis = test::AsTensor<int32>({1});
  Tensor dy(DT_FLOAT, {2, 3});
  test::FillIota<float>(&dy, 1);
  auto dx = ReverseV2Grad(x, axis, dy);
  test::ExpectTensorEqual<float>(
      dx[0], test::AsTensor<float>({3., 2., 1., 6., 5., 4.}, {2, 3}));
  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0}));
}

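// Builds a graph that takes the symbolic gradient of Slice(x, b, s) with
// begin `b` and size `s`, and returns the gradients with respect to x, b,
// and s.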
std::vector<Tensor> SliceGrad(const Tensor& x, const Tensor& b, const Tensor& s,
                              const Tensor& dy) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("s", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef(
           "dx", "SymbolicGradient", {"x", "b", "s", "dy"},
           {{"f", FDH::FunctionRef("Slice", {{"T", T}, {"Index", DT_INT32}})},
            {"Tin", DataTypeSlice{T, DT_INT32, DT_INT32, T}},
            {"Tout", DataTypeSlice{T, DT_INT32, DT_INT32}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"b:0", b}, {"s:0", s}, {"dy:0", dy}},
                        {"dx:0", "dx:1", "dx:2"}, {}, &out));
  CHECK_EQ(out.size(), 3);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, SliceGrad) {
  Tensor x(DT_FLOAT, {2, 3, 4});
  x.flat<float>().setZero();
  auto begin = test::AsTensor<int32>({1, 1, 1});
  auto size = test::AsTensor<int32>({1, 2, 2});
  Tensor dy(DT_FLOAT, {1, 2, 2});
  test::FillIota<float>(&dy, 1);
  auto dx = SliceGrad(x, begin, size, dy);
  test::ExpectClose(dx[0],
                    test::AsTensor<float>(
                        {
                            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                            0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0.,
                        },
                        {2, 3, 4}));
  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
  test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
}

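// Builds a graph that takes the symbolic gradient of StridedSlice with the
// given begin/end/strides tensors and mask attributes, and returns the
// gradients with respect to x, begin, end, and strides.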
std::vector<Tensor> StridedSliceGrad(const Tensor& x, const Tensor& begin,
                                     const Tensor& end, const Tensor& strides,
                                     const Tensor& dy, int32 begin_mask,
                                     int32 end_mask, int32 ellipsis_mask,
                                     int32 new_axis_mask,
                                     int32 shrink_axis_mask) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("begin", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("end", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("strides", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef(
           "dx", "SymbolicGradient", {"x", "begin", "end", "strides", "dy"},
           {{"f", FDH::FunctionRef("StridedSlice",
                                   {
                                       {"T", T},
                                       {"Index", DT_INT32},
                                       {"begin_mask", begin_mask},
                                       {"end_mask", end_mask},
                                       {"new_axis_mask", new_axis_mask},
                                       {"shrink_axis_mask", shrink_axis_mask},
                                       {"ellipsis_mask", ellipsis_mask},
                                   })},
            {"Tin", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32, T}},
            {"Tout", DataTypeSlice{T, DT_INT32, DT_INT32, DT_INT32}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x},
                         {"begin:0", begin},
                         {"end:0", end},
                         {"strides:0", strides},
                         {"dy:0", dy}},
                        {"dx:0", "dx:1", "dx:2", "dx:3"}, {}, &out));
  CHECK_EQ(out.size(), 4);
  TF_CHECK_OK(sess->Close());
  return out;
}

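// Builds a graph that takes the symbolic gradient of StridedSliceGrad, i.e.
// the second-order gradient of StridedSlice, and returns the gradients with
// respect to shape, begin, end, strides, and dy.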
std::vector<Tensor> StridedSliceGradGrad(
    const Tensor& shape, const Tensor& begin, const Tensor& end,
    const Tensor& strides, const Tensor& dy, const Tensor& grad,
    int32 begin_mask, int32 end_mask, int32 ellipsis_mask, int32 new_axis_mask,
    int32 shrink_axis_mask) {
  auto T = DT_FLOAT;
  auto gdef = test::function::GDef(
      {f::NDef("shape", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("begin", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("end", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("strides", "Placeholder", {}, {{"dtype", DT_INT32}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("grad", "Placeholder", {}, {{"dtype", T}}),
       f::NDef(
           "dx", "SymbolicGradient",
           {"shape", "begin", "end", "strides", "dy", "grad"},
           {{"f", FDH::FunctionRef("StridedSliceGrad",
                                   {
                                       {"T", T},
                                       {"Index", DT_INT32},
                                       {"begin_mask", begin_mask},
                                       {"end_mask", end_mask},
                                       {"new_axis_mask", new_axis_mask},
                                       {"shrink_axis_mask", shrink_axis_mask},
                                       {"ellipsis_mask", ellipsis_mask},
                                   })},
            {"Tin",
             DataTypeSlice{DT_INT32, DT_INT32, DT_INT32, DT_INT32, T, T}},
            {"Tout",
             DataTypeSlice{DT_INT32, DT_INT32, DT_INT32, DT_INT32, T}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"shape:0", shape},
                         {"begin:0", begin},
                         {"end:0", end},
                         {"strides:0", strides},
                         {"dy:0", dy},
                         {"grad:0", grad}},
                        {"dx:0", "dx:1", "dx:2", "dx:3", "dx:4"}, {}, &out));
  CHECK_EQ(out.size(), 5);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, StridedSliceGrad) {
  Tensor x(DT_FLOAT, {2, 3, 4});
  x.flat<float>().setZero();
  Tensor x_shape = test::AsTensor<int32>({2, 3, 4}, {3});

  {
    auto start = test::AsTensor<int32>({1, 1, 1});
    auto stop = test::AsTensor<int32>({2, 3, 3});
    auto strides = test::AsTensor<int32>({1, 1, 1});
    Tensor dy(DT_FLOAT, {1, 2, 2});
    test::FillIota<float>(&dy, 1);
    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0,
        ellipsis_mask = 0;
    auto dx =
        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
    test::ExpectClose(dx[0],
                      test::AsTensor<float>(
                          {
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0.,
                          },
                          {2, 3, 4}));
    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
                                    begin_mask, end_mask, ellipsis_mask,
                                    new_axis_mask, shrink_axis_mask);
    test::ExpectClose(ddx[4], dy);
  }

  // test equivalent of python tf.gradients(foo[1:2, 1:3, 1:3])
  {
    auto start = test::AsTensor<int32>({1, 1, 1});
    auto stop = test::AsTensor<int32>({2, 3, 3});
    auto strides = test::AsTensor<int32>({1, 1, 1});
    Tensor dy(DT_FLOAT, {1, 2, 2});
    test::FillIota<float>(&dy, 1);
    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 0,
        ellipsis_mask = 0;
    auto dx =
        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
    test::ExpectClose(dx[0],
                      test::AsTensor<float>(
                          {
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0.,
                          },
                          {2, 3, 4}));
    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}));
    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0}));
    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
                                    begin_mask, end_mask, ellipsis_mask,
                                    new_axis_mask, shrink_axis_mask);
    test::ExpectClose(ddx[4], dy);
  }

  // test equivalent of python tf.gradients(foo[1, 1:, :-2, None])
  {
    int dontcare = 66;
    auto start = test::AsTensor<int32>({1, 1, dontcare, dontcare});
    auto stop = test::AsTensor<int32>({2, dontcare, -2, dontcare});
    auto strides = test::AsTensor<int32>({1, 1, 1, dontcare});
    Tensor dy(DT_FLOAT, {2, 2, 1});
    test::FillIota<float>(&dy, 1);
    int begin_mask = 4, end_mask = 2, new_axis_mask = 8, shrink_axis_mask = 1,
        ellipsis_mask = 0;
    auto dx =
        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
    test::ExpectClose(dx[0],
                      test::AsTensor<float>(
                          {
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                              0., 0., 0., 0., 1., 2., 0., 0., 3., 4., 0., 0.,
                          },
                          {2, 3, 4}));
    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0, 0}));
    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0, 0, 0}));
    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
                                    begin_mask, end_mask, ellipsis_mask,
                                    new_axis_mask, shrink_axis_mask);
    test::ExpectClose(ddx[4], dy);
  }

  // test equivalent of tf.gradients(foo[1, ...]) i.e. foo[1, 0:3, 0:4]
  {
    int dontcare = 66;
    auto start = test::AsTensor<int32>({1, dontcare});
    auto stop = test::AsTensor<int32>({2, dontcare});
    auto strides = test::AsTensor<int32>({1, 1});
    Tensor dy(DT_FLOAT, {3, 4});
    test::FillIota<float>(&dy, 1);
    int begin_mask = 0, end_mask = 0, new_axis_mask = 0, shrink_axis_mask = 1,
        ellipsis_mask = 2;
    auto dx =
        StridedSliceGrad(x, start, stop, strides, dy, begin_mask, end_mask,
                         ellipsis_mask, new_axis_mask, shrink_axis_mask);
    test::ExpectClose(dx[0],
                      test::AsTensor<float>(
                          {
                              0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,  0.,  0.,
                              1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.,
                          },
                          {2, 3, 4}));
    test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0}));
    test::ExpectTensorEqual<int32>(dx[2], test::AsTensor<int32>({0, 0}));
    auto ddx = StridedSliceGradGrad(x_shape, start, stop, strides, dy, dx[0],
                                    begin_mask, end_mask, ellipsis_mask,
                                    new_axis_mask, shrink_axis_mask);
    test::ExpectClose(ddx[4], dy);
  }
}

}  // namespace
}  // namespace tensorflow