Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7 http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <complex>
     17 #include <string>
     18 
     19 #include "tensorflow/core/framework/fake_input.h"
     20 #include "tensorflow/core/framework/node_def_builder.h"
     21 #include "tensorflow/core/framework/tensor.h"
     22 #include "tensorflow/core/framework/tensor_shape.h"
     23 #include "tensorflow/core/framework/types.h"
     24 #include "tensorflow/core/framework/types.pb.h"
     25 #include "tensorflow/core/kernels/ops_testutil.h"
     26 #include "tensorflow/core/lib/core/status.h"
     27 #include "tensorflow/core/lib/io/path.h"
     28 #include "tensorflow/core/platform/env.h"
     29 #include "tensorflow/core/platform/test.h"
     30 #include "tensorflow/core/platform/types.h"
     31 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
     32 
     33 namespace tensorflow {
     34 namespace {
     35 
     36 class SaveV2OpTest : public OpsTestBase {
     37  protected:
     38   void MakeOp() {
     39     TF_ASSERT_OK(NodeDefBuilder("myop", "SaveV2")
     40                      .Input(FakeInput())  // prefix
     41                      .Input(FakeInput())  // tensor_names
     42                      .Input(FakeInput())  // shape_and_slices
     43                      .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE,
     44                                        DT_QINT8, DT_QINT32, DT_UINT8, DT_INT8,
     45                                        DT_INT16, DT_INT64, DT_COMPLEX64,
     46                                        DT_COMPLEX128, DT_HALF}))  // tensors
     47                      .Finalize(node_def()));
     48     TF_ASSERT_OK(InitOp());
     49   }
     50 };
     51 
     52 TEST_F(SaveV2OpTest, Simple) {
     53   const string prefix = io::JoinPath(testing::TmpDir(), "tensor_simple");
     54   const string tensornames[] = {
     55       "tensor_bool",  "tensor_int",    "tensor_float",     "tensor_double",
     56       "tensor_qint8", "tensor_qint32", "tensor_uint8",     "tensor_int8",
     57       "tensor_int16", "tensor_int64",  "tensor_complex64", "tensor_complex128",
     58       "tensor_half"};
     59 
     60   MakeOp();
     61   // Add a file name
     62   AddInput<string>(TensorShape({}),
     63                    [&prefix](int x) -> string { return prefix; });
     64 
     65   // Add the tensor names
     66   AddInput<string>(TensorShape({13}),
     67                    [&tensornames](int x) -> string { return tensornames[x]; });
     68 
     69   // Add the slice specs
     70   AddInput<string>(TensorShape({13}), [&tensornames](int x) -> string {
     71     return "" /* saves in full */;
     72   });
     73 
     74   // Add a 1-d bool tensor
     75   AddInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; });
     76 
     77   // Add a 1-d integer tensor
     78   AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
     79 
     80   // Add a 2-d float tensor
     81   AddInput<float>(TensorShape({2, 4}),
     82                   [](int x) -> float { return static_cast<float>(x) / 10; });
     83 
     84   // Add a 2-d double tensor
     85   AddInput<double>(TensorShape({2, 4}),
     86                    [](int x) -> double { return static_cast<double>(x) / 20; });
     87 
     88   // Add a 2-d qint8 tensor
     89   AddInput<qint8>(TensorShape({3, 2}),
     90                   [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
     91 
     92   // Add a 2-d qint32 tensor
     93   AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
     94     return *reinterpret_cast<qint32*>(&x) * qint8(2);
     95   });
     96 
     97   // Add a 1-d uint8 tensor
     98   AddInput<uint8>(TensorShape({11}), [](int x) -> uint8 { return x + 1; });
     99 
    100   // Add a 1-d int8 tensor
    101   AddInput<int8>(TensorShape({7}), [](int x) -> int8 { return x - 7; });
    102 
    103   // Add a 1-d int16 tensor
    104   AddInput<int16>(TensorShape({7}), [](int x) -> int16 { return x - 8; });
    105 
    106   // Add a 1-d int64 tensor
    107   AddInput<int64>(TensorShape({9}), [](int x) -> int64 { return x - 9; });
    108 
    109   // Add a 2-d complex64 tensor
    110   AddInput<complex64>(TensorShape({2, 3}), [](int x) -> complex64 {
    111     return complex64(100 + x, 200 + x);
    112   });
    113 
    114   // Add a 2-d complex128 tensor
    115   AddInput<complex128>(TensorShape({2, 3}), [](int x) -> complex128 {
    116     return complex128(100 + x, 200 + x);
    117   });
    118 
    119   // Add a 2-d half tensor
    120   AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
    121     return static_cast<Eigen::half>(x) / Eigen::half(2);
    122   });
    123   TF_ASSERT_OK(RunOpKernel());
    124 
    125   // Check that the checkpoint file is properly written
    126   BundleReader reader(Env::Default(), prefix);
    127   TF_EXPECT_OK(reader.status());
    128 
    129   // We expect to find all saved tensors
    130   {
    131     // The 1-d bool tensor
    132     TensorShape shape;
    133     TF_EXPECT_OK(reader.LookupTensorShape("tensor_bool", &shape));
    134     TensorShape expected({2});
    135     EXPECT_TRUE(shape.IsSameSize(expected));
    136 
    137     // We expect the tensor value to be correct.
    138     Tensor val;
    139     TF_EXPECT_OK(reader.Lookup("tensor_bool", &val));
    140     EXPECT_EQ(DT_BOOL, val.dtype());
    141     for (int i = 0; i < 2; ++i) {
    142       EXPECT_EQ((i != 0), val.template flat<bool>()(i));
    143     }
    144   }
    145 
    146   {
    147     // The 1-d integer tensor
    148     TensorShape shape;
    149     TF_EXPECT_OK(reader.LookupTensorShape("tensor_int", &shape));
    150     TensorShape expected({10});
    151     EXPECT_TRUE(shape.IsSameSize(expected));
    152 
    153     // We expect the tensor value to be correct.
    154     Tensor val;
    155     TF_EXPECT_OK(reader.Lookup("tensor_int", &val));
    156     EXPECT_EQ(DT_INT32, val.dtype());
    157     for (int i = 0; i < 10; ++i) {
    158       EXPECT_EQ(i + 1, val.template flat<int>()(i));
    159     }
    160   }
    161 
    162   {
    163     // The 2-d float tensor
    164     TensorShape shape;
    165     TF_EXPECT_OK(reader.LookupTensorShape("tensor_float", &shape));
    166     TensorShape expected({2, 4});
    167     EXPECT_TRUE(shape.IsSameSize(expected));
    168 
    169     // We expect the tensor value to be correct.
    170     Tensor val;
    171     TF_EXPECT_OK(reader.Lookup("tensor_float", &val));
    172     EXPECT_EQ(DT_FLOAT, val.dtype());
    173     for (int i = 0; i < 8; ++i) {
    174       EXPECT_EQ(static_cast<float>(i) / 10, val.template flat<float>()(i));
    175     }
    176   }
    177 
    178   {
    179     // The 2-d double tensor
    180     TensorShape shape;
    181     TF_EXPECT_OK(reader.LookupTensorShape("tensor_double", &shape));
    182     TensorShape expected({2, 4});
    183     EXPECT_TRUE(shape.IsSameSize(expected));
    184 
    185     // We expect the tensor value to be correct.
    186     Tensor val;
    187     TF_EXPECT_OK(reader.Lookup("tensor_double", &val));
    188     EXPECT_EQ(DT_DOUBLE, val.dtype());
    189     for (int i = 0; i < 8; ++i) {
    190       EXPECT_EQ(static_cast<double>(i) / 20, val.template flat<double>()(i));
    191     }
    192   }
    193 
    194   {
    195     // The 2-d qint8 tensor
    196     TensorShape shape;
    197     TF_EXPECT_OK(reader.LookupTensorShape("tensor_qint8", &shape));
    198     TensorShape expected({3, 2});
    199     EXPECT_TRUE(shape.IsSameSize(expected));
    200 
    201     // We expect the tensor value to be correct.
    202     Tensor val;
    203     TF_EXPECT_OK(reader.Lookup("tensor_qint8", &val));
    204     EXPECT_EQ(DT_QINT8, val.dtype());
    205     for (int i = 0; i < 6; ++i) {
    206       EXPECT_EQ(*reinterpret_cast<qint8*>(&i), val.template flat<qint8>()(i));
    207     }
    208   }
    209 
    210   {
    211     // The 2-d qint32 tensor
    212     TensorShape shape;
    213     TF_EXPECT_OK(reader.LookupTensorShape("tensor_qint32", &shape));
    214     TensorShape expected({2, 3});
    215     EXPECT_TRUE(shape.IsSameSize(expected));
    216 
    217     // We expect the tensor value to be correct.
    218     Tensor val;
    219     TF_EXPECT_OK(reader.Lookup("tensor_qint32", &val));
    220     EXPECT_EQ(DT_QINT32, val.dtype());
    221     for (int i = 0; i < 6; ++i) {
    222       EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2),
    223                 val.template flat<qint32>()(i));
    224     }
    225   }
    226 
    227   {
    228     // The 1-d uint8 tensor
    229     TensorShape shape;
    230     TF_EXPECT_OK(reader.LookupTensorShape("tensor_uint8", &shape));
    231     TensorShape expected({11});
    232     EXPECT_TRUE(shape.IsSameSize(expected));
    233 
    234     // We expect the tensor value to be correct.
    235     Tensor val;
    236     TF_EXPECT_OK(reader.Lookup("tensor_uint8", &val));
    237     EXPECT_EQ(DT_UINT8, val.dtype());
    238     for (int i = 0; i < 11; ++i) {
    239       EXPECT_EQ(i + 1, val.template flat<uint8>()(i));
    240     }
    241   }
    242 
    243   {
    244     // The 1-d int8 tensor
    245     TensorShape shape;
    246     TF_EXPECT_OK(reader.LookupTensorShape("tensor_int8", &shape));
    247     TensorShape expected({7});
    248     EXPECT_TRUE(shape.IsSameSize(expected));
    249 
    250     // We expect the tensor value to be correct.
    251     Tensor val;
    252     TF_EXPECT_OK(reader.Lookup("tensor_int8", &val));
    253     EXPECT_EQ(DT_INT8, val.dtype());
    254     for (int i = 0; i < 7; ++i) {
    255       EXPECT_EQ(i - 7, val.template flat<int8>()(i));
    256     }
    257   }
    258 
    259   {
    260     // The 1-d int16 tensor
    261     TensorShape shape;
    262     TF_EXPECT_OK(reader.LookupTensorShape("tensor_int16", &shape));
    263     TensorShape expected({7});
    264     EXPECT_TRUE(shape.IsSameSize(expected));
    265 
    266     // We expect the tensor value to be correct.
    267     Tensor val;
    268     TF_EXPECT_OK(reader.Lookup("tensor_int16", &val));
    269     EXPECT_EQ(DT_INT16, val.dtype());
    270     for (int i = 0; i < 7; ++i) {
    271       EXPECT_EQ(i - 8, val.template flat<int16>()(i));
    272     }
    273   }
    274 
    275   {
    276     // The 1-d int64 tensor
    277     TensorShape shape;
    278     TF_EXPECT_OK(reader.LookupTensorShape("tensor_int64", &shape));
    279     TensorShape expected({9});
    280     EXPECT_TRUE(shape.IsSameSize(expected));
    281 
    282     // We expect the tensor value to be correct.
    283     Tensor val;
    284     TF_EXPECT_OK(reader.Lookup("tensor_int64", &val));
    285     EXPECT_EQ(DT_INT64, val.dtype());
    286     for (int i = 0; i < 9; ++i) {
    287       EXPECT_EQ(i - 9, val.template flat<int64>()(i));
    288     }
    289   }
    290 
    291   {
    292     // The 2-d complex64 tensor
    293     TensorShape shape;
    294     TF_EXPECT_OK(reader.LookupTensorShape("tensor_complex64", &shape));
    295     TensorShape expected({2, 3});
    296     EXPECT_TRUE(shape.IsSameSize(expected));
    297 
    298     // We expect the tensor value to be correct.
    299     Tensor val;
    300     TF_EXPECT_OK(reader.Lookup("tensor_complex64", &val));
    301     EXPECT_EQ(DT_COMPLEX64, val.dtype());
    302     for (int i = 0; i < 6; ++i) {
    303       EXPECT_EQ(100 + i, val.template flat<complex64>()(i).real());
    304       EXPECT_EQ(200 + i, val.template flat<complex64>()(i).imag());
    305     }
    306   }
    307 
    308   {
    309     // The 2-d complex128 tensor
    310     TensorShape shape;
    311     TF_EXPECT_OK(reader.LookupTensorShape("tensor_complex128", &shape));
    312     TensorShape expected({2, 3});
    313     EXPECT_TRUE(shape.IsSameSize(expected));
    314 
    315     // We expect the tensor value to be correct.
    316     Tensor val;
    317     TF_EXPECT_OK(reader.Lookup("tensor_complex128", &val));
    318     EXPECT_EQ(DT_COMPLEX128, val.dtype());
    319     for (int i = 0; i < 6; ++i) {
    320       EXPECT_EQ(100 + i, val.template flat<complex128>()(i).real());
    321       EXPECT_EQ(200 + i, val.template flat<complex128>()(i).imag());
    322     }
    323   }
    324   {
    325     // The 2-d half tensor
    326     TensorShape shape;
    327     TF_EXPECT_OK(reader.LookupTensorShape("tensor_half", &shape));
    328     TensorShape expected({2, 4});
    329     EXPECT_TRUE(shape.IsSameSize(expected));
    330 
    331     // We expect the tensor value to be correct.
    332     Tensor val;
    333     TF_EXPECT_OK(reader.Lookup("tensor_half", &val));
    334     EXPECT_EQ(DT_HALF, val.dtype());
    335     for (int i = 0; i < 8; ++i) {
    336       EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2),
    337                 val.template flat<Eigen::half>()(i));
    338     }
    339   }
    340 }
    341 
    342 }  // namespace
    343 }  // namespace tensorflow
    344