/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <memory>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/io_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {
namespace {

class SaveOpTest : public OpsTestBase {
 protected:
  void MakeOp() {
    TF_ASSERT_OK(
        NodeDefBuilder("myop", "Save")
            .Input(FakeInput())
            .Input(FakeInput())
            .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
                              DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_INT64,
                              DT_STRING, DT_COMPLEX64, DT_COMPLEX128, DT_HALF}))
            .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

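// The "Save" op, as exercised below, takes a scalar string filename, a 1-D
// string tensor of tensor names (one per saved tensor), and then the data
// tensors themselves, in the same order as the names.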
TEST_F(SaveOpTest, Simple) {
  const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
  const string tensornames[] = {
      "tensor_bool",       "tensor_int",    "tensor_float",  "tensor_double",
      "tensor_qint8",      "tensor_qint32", "tensor_uint8",  "tensor_int8",
      "tensor_int16",      "tensor_int64",  "tensor_string", "tensor_complex64",
      "tensor_complex128", "tensor_half"};

  MakeOp();
  // Add a file name
  AddInput<string>(TensorShape({}),
                   [&filename](int x) -> string { return filename; });

  // Add the tensor names
  AddInput<string>(TensorShape({14}),
                   [&tensornames](int x) -> string { return tensornames[x]; });

  // Add a 1-d bool tensor
  AddInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; });

  // Add a 1-d integer tensor
  AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });

  // Add a 2-d float tensor
  AddInput<float>(TensorShape({2, 4}),
                  [](int x) -> float { return static_cast<float>(x) / 10; });

  // Add a 2-d double tensor
  AddInput<double>(TensorShape({2, 4}),
                   [](int x) -> double { return static_cast<double>(x) / 20; });

  // Add a 2-d qint8 tensor
  AddInput<qint8>(TensorShape({3, 2}),
                  [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });

  // Add a 2-d qint32 tensor
  AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
    return *reinterpret_cast<qint32*>(&x) * qint8(2);
  });

  // Add a 1-d uint8 tensor
  AddInput<uint8>(TensorShape({11}), [](int x) -> uint8 { return x + 1; });

  // Add a 1-d int8 tensor
  AddInput<int8>(TensorShape({7}), [](int x) -> int8 { return x - 7; });

  // Add a 1-d int16 tensor
  AddInput<int16>(TensorShape({7}), [](int x) -> int16 { return x - 8; });

  // Add a 1-d int64 tensor
  AddInput<int64>(TensorShape({9}), [](int x) -> int64 { return x - 9; });

  // Add a 1-d string tensor
  AddInput<string>(TensorShape({2}),
                   [](int x) -> string { return x ? "yes" : "no"; });

  // Add a 2-d complex64 tensor
  AddInput<complex64>(TensorShape({2, 3}), [](int x) -> complex64 {
    return complex64(100 + x, 200 + x);
  });

  // Add a 2-d complex128 tensor
  AddInput<complex128>(TensorShape({2, 3}), [](int x) -> complex128 {
    return complex128(100 + x, 200 + x);
  });

  // Add a 2-d half tensor
  AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
    return static_cast<Eigen::half>(x) / Eigen::half(2);
  });
  TF_ASSERT_OK(RunOpKernel());

  // Check that the checkpoint file is properly written
  checkpoint::TensorSliceReader reader(filename,
                                       checkpoint::OpenTableTensorSliceReader);
  TF_EXPECT_OK(reader.status());

  // We expect to find all saved tensors
  {
    // The 1-d bool tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_bool", &shape, &type));
    TensorShape expected({2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_BOOL, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
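    // "-" selects the full extent of the single dimension; the 2-D reads
    // below use "-:-" in the same way.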
    bool data[2];
    std::fill_n(data, 2, false);
    EXPECT_TRUE(reader.CopySliceData("tensor_bool", s, data));
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ((i != 0), data[i]);
    }
  }

  {
    // The 1-d integer tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
    TensorShape expected({10});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT32, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int data[10];
    std::fill_n(data, 10, 0);
    EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
    for (int i = 0; i < 10; ++i) {
      EXPECT_EQ(i + 1, data[i]);
    }
  }

  {
    // The 2-d float tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_FLOAT, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    float data[8];
    std::fill_n(data, 8, 0);
    EXPECT_TRUE(reader.CopySliceData("tensor_float", s, data));
    for (int i = 0; i < 8; ++i) {
      EXPECT_EQ(static_cast<float>(i) / 10, data[i]);
    }
  }

  {
    // The 2-d double tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_DOUBLE, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    double data[8];
    std::fill_n(data, 8, 0);
    EXPECT_TRUE(reader.CopySliceData("tensor_double", s, data));
    for (int i = 0; i < 8; ++i) {
      EXPECT_EQ(static_cast<double>(i) / 20, data[i]);
    }
  }

  {
    // The 2-d qint8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
    TensorShape expected({3, 2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT8, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    qint8 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(*reinterpret_cast<qint8*>(&i), data[i]);
    }
  }

  {
    // The 2-d qint32 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT32, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    qint32 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_qint32", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2), data[i]);
    }
  }

  {
    // The 1-d uint8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_uint8", &shape, &type));
    TensorShape expected({11});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_UINT8, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    uint8 data[11];
    EXPECT_TRUE(reader.CopySliceData("tensor_uint8", s, data));
    for (int i = 0; i < 11; ++i) {
      EXPECT_EQ(i + 1, data[i]);
    }
  }

  {
    // The 1-d int8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int8", &shape, &type));
    TensorShape expected({7});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT8, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int8 data[7];
    EXPECT_TRUE(reader.CopySliceData("tensor_int8", s, data));
    for (int i = 0; i < 7; ++i) {
      EXPECT_EQ(i - 7, data[i]);
    }
  }

  {
    // The 1-d int16 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int16", &shape, &type));
    TensorShape expected({7});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT16, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int16 data[7];
    EXPECT_TRUE(reader.CopySliceData("tensor_int16", s, data));
    for (int i = 0; i < 7; ++i) {
      EXPECT_EQ(i - 8, data[i]);
    }
  }

  {
    // The 1-d int64 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int64", &shape, &type));
    TensorShape expected({9});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT64, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int64 data[9];
    EXPECT_TRUE(reader.CopySliceData("tensor_int64", s, data));
    for (int i = 0; i < 9; ++i) {
      EXPECT_EQ(i - 9, data[i]);
    }
  }

  {
    // The 1-d string tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_string", &shape, &type));
    TensorShape expected({2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_STRING, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    string data[2];
    EXPECT_TRUE(reader.CopySliceData("tensor_string", s, data));
    EXPECT_EQ("no", data[0]);
    EXPECT_EQ("yes", data[1]);
  }

  {
    // The 2-d complex64 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_complex64", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_COMPLEX64, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    complex64 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_complex64", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(100 + i, data[i].real());
      EXPECT_EQ(200 + i, data[i].imag());
    }
  }

  {
    // The 2-d complex128 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_complex128", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_COMPLEX128, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    complex128 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_complex128", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(100 + i, data[i].real());
      EXPECT_EQ(200 + i, data[i].imag());
    }
  }

  {
    // The 2-d half tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_half", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_HALF, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    Eigen::half data[8];
    std::fill_n(data, 8, Eigen::half(0));
    EXPECT_TRUE(reader.CopySliceData("tensor_half", s, data));
    for (int i = 0; i < 8; ++i) {
      EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2), data[i]);
    }
  }
}

class SaveSlicesOpTest : public OpsTestBase {
 protected:
  void MakeOp() {
    TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput(
                         {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

// Here we save only slices.  We restore them into a larger tensor and we check
// that the right slice is restored.  It is quite tricky to check that the
// right slices are actually restored, so instead we just check that
// CopySliceData() returns true/false depending on the slice we ask for.
TEST_F(SaveSlicesOpTest, Slices) {
  const string filename = io::JoinPath(testing::TmpDir(), "tensor_slices");
  const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
                                "tensor_qint8", "tensor_qint32"};
  // Specifies that the data we save are slices of larger tensors.
  // See core/framework/tensor_slice.h for the slice syntax.
  const string tensorshapes[] = {
      "10 -",         // Full contents of a 10-element vector.
      "2 4 -:0,2",    // A 2x2 slice of a 2x4 tensor.
      "2 4 0,1:2,2",  // A 1x2 slice of a 2x4 tensor.
      "3 2 -:-",      // Full contents of a 3x2 tensor.
      "2 3 1,1:2,1"   // A 1x1 slice of a 2x3 tensor.
  };
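  // As the entries above illustrate, each shape-and-slice string is
  // "<shape> <slice>": the dimensions of the full tensor, then a ':'-separated
  // per-dimension spec where "-" keeps the whole dimension and "start,length"
  // keeps a sub-range.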

  MakeOp();
  // Add a file name
  AddInput<string>(TensorShape({}),
                   [&filename](int x) -> string { return filename; });

  // Add the tensor names
  AddInput<string>(TensorShape({5}),
                   [&tensornames](int x) -> string { return tensornames[x]; });

  // Add the tensor shapes and slices
  AddInput<string>(TensorShape({5}), [&tensorshapes](int x) -> string {
    return tensorshapes[x];
  });

  // Add a 1-d integer tensor
  AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });

  // Add a 2-d float tensor
  AddInput<float>(TensorShape({2, 2}),
                  [](int x) -> float { return static_cast<float>(x) / 10; });

  // Add a 2-d double tensor
  AddInput<double>(TensorShape({1, 2}),
                   [](int x) -> double { return static_cast<double>(x) / 20; });

  // Add a 2-d qint8 tensor
  AddInput<qint8>(TensorShape({3, 2}),
                  [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });

  // Add a 2-d qint32 tensor
  AddInput<qint32>(TensorShape({1, 1}), [](int x) -> qint32 {
    return *reinterpret_cast<qint32*>(&x) * qint8(2);
  });

  TF_ASSERT_OK(RunOpKernel());

  // Check that the checkpoint file is properly written
  checkpoint::TensorSliceReader reader(filename,
                                       checkpoint::OpenTableTensorSliceReader);
  TF_EXPECT_OK(reader.status());

  // We expect to find all saved tensors
  {
    // The 1-d integer tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
    TensorShape expected({10});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT32, type);

    // We saved the full tensor so we should be able to read it all.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int data[10];
    EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
  }

  {
    // The 2-d float tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_FLOAT, type);

    // We saved the slice "-:0,2" so we should not be able to read the full
    // tensor.
    TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
    TensorSlice saved_slice = TensorSlice::ParseOrDie("-:0,2");
    float data[8];
    EXPECT_FALSE(reader.CopySliceData("tensor_float", full_slice, data));
    EXPECT_TRUE(reader.CopySliceData("tensor_float", saved_slice, data));
  }

  {
    // The 2-d double tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_DOUBLE, type);

    // We saved the slice "0,1:2,2" so we should not be able to read the full
    // tensor.
    TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
    TensorSlice saved_slice = TensorSlice::ParseOrDie("0,1:2,2");
    double data[8];
    EXPECT_FALSE(reader.CopySliceData("tensor_double", full_slice, data));
    EXPECT_TRUE(reader.CopySliceData("tensor_double", saved_slice, data));
  }

  {
    // The 2-d qint8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
    TensorShape expected({3, 2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT8, type);

    // We saved the full slice.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    qint8 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
  }

  {
    // The 2-d qint32 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT32, type);

    // We saved the slice "1,1:2,1" so we should not be able to read the full
    // tensor.
    TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
    TensorSlice saved_slice = TensorSlice::ParseOrDie("1,1:2,1");
    qint32 data[6];
    EXPECT_FALSE(reader.CopySliceData("tensor_qint32", full_slice, data));
    EXPECT_TRUE(reader.CopySliceData("tensor_qint32", saved_slice, data));
  }
}

class SaveOpSlices2Test : public OpsTestBase {
 protected:
  void MakeOp() {
    TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput({DT_INT32, DT_INT32, DT_FLOAT}))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

TEST_F(SaveOpSlices2Test, TwoSlices) {
  const string filename = io::JoinPath(testing::TmpDir(), "three_slices");
  // We will save 2 slices of the tensor named "four_by_sixteen", which is
  // 4x16, and one slice of the "small" tensor.
  const string tensornames[] = {"four_by_sixteen", "four_by_sixteen", "small"};
  const string tensorshapes[] = {
      // Slice specifications for the 2 slices of "four_by_sixteen"
      "4 16 0,2:-",  // 1st slice covers indices 0 and 1 in the first dim.
      "4 16 2,2:-",  // 2nd slice covers indices 2 and 3 in the first dim.
      ""             // We save the full "small" tensor.
  };
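  // The two row slices are disjoint and together cover the full 4x16 shape,
  // so the reader below can reassemble the complete tensor from them.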

  MakeOp();
  // Add a file name
  AddInput<string>(TensorShape({}),
                   [&filename](int x) -> string { return filename; });

  // Add the tensor names
  AddInput<string>(TensorShape({3}),
                   [&tensornames](int x) -> string { return tensornames[x]; });

  // Add the tensor shapes and slices
  AddInput<string>(TensorShape({3}), [&tensorshapes](int x) -> string {
    return tensorshapes[x];
  });

  // Add an integer tensor for slice 0,2:- of a 4x16 tensor: It is 2x16.
  AddInput<int32>(TensorShape({2, 16}), [](int x) -> int32 { return x + 1; });

  // Add an integer tensor for slice 2,2:- of a 4x16 tensor: It is 2x16.
  AddInput<int32>(TensorShape({2, 16}),
                  [](int x) -> int32 { return 10 * (x + 1); });

  // Add a float tensor for "small"
  AddInput<float>(TensorShape({2, 4}),
                  [](int x) -> float { return static_cast<float>(x) / 10; });

  TF_ASSERT_OK(RunOpKernel());

  // Check that the checkpoint file is properly written
  checkpoint::TensorSliceReader reader(filename,
                                       checkpoint::OpenTableTensorSliceReader);
  TF_EXPECT_OK(reader.status());

  {
    // Reload the two slices of "four_by_sixteen" into that tensor.
    Tensor reloaded(DT_INT32, {4, 16});

    // We expect to find all slices
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("four_by_sixteen", &shape, &type));
    EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
    EXPECT_EQ(type, reloaded.dtype());

    // Reload the whole tensor.
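    // The two row slices written above are stitched back together here;
    // TensorSlice(reloaded.dims()) is a full slice over every dimension.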
    EXPECT_TRUE(reader.CopySliceData("four_by_sixteen",
                                     TensorSlice(reloaded.dims()),
                                     reloaded.flat<int>().data()));

    {
      auto slice = reloaded.Slice(0, 2).flat<int>();
      for (int i = 0; i < slice.size(); ++i) {
        EXPECT_EQ(i + 1, slice(i));
      }
    }
    {
      auto slice = reloaded.Slice(2, 4).flat<int>();
      for (int i = 0; i < slice.size(); ++i) {
        EXPECT_EQ(10 * (i + 1), slice(i));
      }
    }
  }

  {
    // Reload the small float tensor.
    Tensor reloaded(DT_FLOAT, {2, 4});

    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("small", &shape, &type));
    EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
    EXPECT_EQ(DT_FLOAT, reloaded.dtype());

    EXPECT_TRUE(reader.CopySliceData("small", TensorSlice(reloaded.dims()),
                                     reloaded.flat<float>().data()));

    for (int64 i = 0; i < reloaded.NumElements(); ++i) {
      EXPECT_EQ(static_cast<float>(i) / 10, reloaded.flat<float>().data()[i]);
    }
  }
}

// Benchmark-related code below.

static void BM_LargeTensorWrite(int iters, int num_elements) {
  testing::StopTiming();
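  // Everything up to StartTiming() below (tensor setup and graph construction)
  // is excluded from the measured time.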

  // 4 * num_elements bytes total, since sizeof(float) == 4.
  Tensor tensor(DT_FLOAT, TensorShape({num_elements}));
  tensor.flat<float>().setZero();

  // Builds the graph.
  const string temp_filename =
      io::JoinPath(testing::TmpDir(), "benchmark_checkpoint");
  auto root = Scope::NewRootScope().ExitOnError();
  const string tensor_name = "my_tensor";
  ops::Save(root, temp_filename, {tensor_name}, {{tensor}});

  // Disables optimizations.
  SessionOptions session_options;
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(tensorflow::OptimizerOptions_Level_L0);

  TF_CHECK_OK(root.status());
  Graph* g = new Graph(OpRegistry::Global());
  TF_CHECK_OK(root.ToGraph(g));
  VLOG(1) << "Save op's output path: " << temp_filename;
  VLOG(1) << "# nodes in Graph: " << g->num_nodes();

  testing::StartTiming();
  test::Benchmark("cpu", g, &session_options).Run(iters);
}
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);

}  // namespace
}  // namespace tensorflow