Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7 http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <complex>
     17 #include <functional>
     18 #include <memory>
     19 #include <string>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/common_runtime/device.h"
     23 #include "tensorflow/core/common_runtime/device_factory.h"
     24 #include "tensorflow/core/framework/allocator.h"
     25 #include "tensorflow/core/framework/fake_input.h"
     26 #include "tensorflow/core/framework/node_def_builder.h"
     27 #include "tensorflow/core/framework/op_kernel.h"
     28 #include "tensorflow/core/framework/tensor.h"
     29 #include "tensorflow/core/framework/types.h"
     30 #include "tensorflow/core/framework/types.pb.h"
     31 #include "tensorflow/core/kernels/ops_testutil.h"
     32 #include "tensorflow/core/lib/io/path.h"
     33 #include "tensorflow/core/platform/test.h"
     34 
     35 namespace tensorflow {
     36 namespace {
     37 
     38 // Make an input tensor with filled results.
     39 template <typename T>
     40 Tensor MakeInput(const TensorShape& shape,
     41                  std::function<T(int)> input_mapping) {
     42   Tensor input(DataTypeToEnum<T>::v(), shape);
     43   test::FillFn(&input, input_mapping);
     44   return input;
     45 }
     46 
     47 class RestoreV2OpTest : public OpsTestBase {
     48  protected:
     49   // Makes an operation to restore two tensors
     50   void MakeRestoreOp(DataType dt) {
     51     TF_ASSERT_OK(NodeDefBuilder("myop", "RestoreV2")
     52                      .Input(FakeInput())    // prefix
     53                      .Input(FakeInput())    // tensor_names
     54                      .Input(FakeInput())    // shape_and_slices
     55                      .Attr("dtypes", {dt})  // dtypes
     56                      .Finalize(node_def()));
     57     TF_ASSERT_OK(InitOp());
     58   }
     59 
     60   void RunTest(StringPiece save_op_to_use) {
     61     const string filename =
     62         io::JoinPath(testing::TmpDir(), "tensor_simple-", save_op_to_use);
     63     const std::vector<string> tensor_names = {
     64         "tensor_bool",  "tensor_int",    "tensor_float",     "tensor_double",
     65         "tensor_qint8", "tensor_qint32", "tensor_uint8",     "tensor_int8",
     66         "tensor_int16", "tensor_int64",  "tensor_complex64", "tensor_half"};
     67 
     68     // We first need to write using the desired save op.
     69     {
     70       // Initialize an operation.
     71       NodeDef save;
     72       if (save_op_to_use != "Save") {
     73         TF_ASSERT_OK(
     74             NodeDefBuilder("myop", save_op_to_use)
     75                 .Input(FakeInput())  // prefix
     76                 .Input(FakeInput())  // tensor_names
     77                 .Input(FakeInput())  // shape_and_slices
     78                 .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE,
     79                                   DT_QINT8, DT_QINT32, DT_UINT8, DT_INT8,
     80                                   DT_INT16, DT_COMPLEX64, DT_HALF}))  // tensors
     81                 .Finalize(&save));
     82       } else {
     83         TF_ASSERT_OK(
     84             NodeDefBuilder("myop", save_op_to_use)
     85                 .Input(FakeInput())  // file
     86                 .Input(FakeInput())  // tensor_names
     87                 .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE,
     88                                   DT_QINT8, DT_QINT32, DT_UINT8, DT_INT8,
     89                                   DT_INT16, DT_COMPLEX64, DT_HALF}))  // tensors
     90                 .Finalize(&save));
     91       }
     92 
     93       std::unique_ptr<Device> device(
     94           DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
     95 
     96       gtl::InlinedVector<TensorValue, 4> inputs;
     97 
     98       Status status;
     99       std::unique_ptr<OpKernel> op(
    100           CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), save,
    101                          TF_GRAPH_DEF_VERSION, &status));
    102       TF_EXPECT_OK(status);
    103 
    104       // Run it
    105 
    106       // Input #0 is the file name
    107       Tensor input_0(DT_STRING, TensorShape({}));
    108       input_0.scalar<string>()() = filename;
    109       inputs.push_back({nullptr, &input_0});
    110 
    111       // Input #1 is the tensor names
    112       Tensor input_1 = MakeInput<string>(
    113           TensorShape({static_cast<int>(tensor_names.size())}),
    114           [&tensor_names](int x) -> string { return tensor_names[x]; });
    115       inputs.push_back({nullptr, &input_1});
    116 
    117       Tensor shape_and_slices = MakeInput<string>(
    118           TensorShape({static_cast<int>(tensor_names.size())}),
    119           [](int x) -> string { return "" /* saves in full */; });
    120       if (save_op_to_use != "Save") {
    121         inputs.push_back({nullptr, &shape_and_slices});
    122       }
    123 
    124       // Input #2 is a 1-d bool tensor
    125       Tensor input_2 = MakeInput<bool>(TensorShape({2}),
    126                                        [](int x) -> bool { return x != 0; });
    127       inputs.push_back({nullptr, &input_2});
    128       // Input #3 is a 1-d integer tensor
    129       Tensor input_3 = MakeInput<int32>(TensorShape({10}),
    130                                         [](int x) -> int32 { return x + 1; });
    131       inputs.push_back({nullptr, &input_3});
    132       // Input #4 is a 2-d float tensor
    133       Tensor input_4 = MakeInput<float>(
    134           TensorShape({2, 4}),
    135           [](int x) -> float { return static_cast<float>(x) / 10; });
    136       inputs.push_back({nullptr, &input_4});
    137       // Input #5 is a 2-d double tensor
    138       Tensor input_5 = MakeInput<double>(
    139           TensorShape({2, 4}),
    140           [](int x) -> double { return static_cast<double>(x) / 20; });
    141       inputs.push_back({nullptr, &input_5});
    142       // Input #6 is a 2-d qint8 tensor
    143       Tensor input_6 = MakeInput<qint8>(
    144           TensorShape({3, 2}),
    145           [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
    146       inputs.push_back({nullptr, &input_6});
    147       // Input #7 is a 2-d qint32 tensor
    148       Tensor input_7 =
    149           MakeInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
    150             return *reinterpret_cast<qint32*>(&x) * qint8(2);
    151           });
    152       inputs.push_back({nullptr, &input_7});
    153       // Input #8 is a 1-d uint8 tensor
    154       Tensor input_8 = MakeInput<uint8>(TensorShape({11}),
    155                                         [](int x) -> uint8 { return x + 1; });
    156       inputs.push_back({nullptr, &input_8});
    157       // Input #9 is a 1-d int8 tensor
    158       Tensor input_9 = MakeInput<int8>(TensorShape({7}),
    159                                        [](int x) -> int8 { return x - 7; });
    160       inputs.push_back({nullptr, &input_9});
    161       // Input #10 is a 1-d int16 tensor
    162       Tensor input_10 = MakeInput<int16>(TensorShape({7}),
    163                                          [](int x) -> int16 { return x - 8; });
    164       inputs.push_back({nullptr, &input_10});
    165       // Input #11 is a 1-d int64 tensor
    166       Tensor input_11 = MakeInput<int64>(TensorShape({9}),
    167                                          [](int x) -> int64 { return x - 9; });
    168       inputs.push_back({nullptr, &input_11});
    169       // Input #12 is a 1-d complex64 tensor
    170       Tensor input_13 = MakeInput<complex64>(
    171           TensorShape({2, 3}),
    172           [](int x) -> complex64 { return complex64(100 + x, 200 + x); });
    173       inputs.push_back({nullptr, &input_13});
    174       // Input #13 is a 2-d half tensor
    175       Tensor input_14 =
    176           MakeInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
    177             return static_cast<Eigen::half>(x) / Eigen::half(5);
    178           });
    179       inputs.push_back({nullptr, &input_14});
    180       OpKernelContext::Params params;
    181       params.device = device.get();
    182       params.frame_iter = FrameAndIter(0, 0);
    183       params.inputs = &inputs;
    184       params.op_kernel = op.get();
    185       std::vector<AllocatorAttributes> attrs;
    186       test::SetOutputAttrs(&params, &attrs);
    187 
    188       OpKernelContext ctx(&params);
    189       op->Compute(&ctx);
    190       TF_EXPECT_OK(ctx.status());
    191     }
    192 
    193     // Now we restore
    194 
    195     // The 1-d bool tensor
    196     {
    197       MakeRestoreOp(DT_BOOL);
    198       AddInput<string>(TensorShape({}),
    199                        [&filename](int x) -> string { return filename; });
    200       AddInput<string>(TensorShape({1}),
    201                        [&](int x) -> string { return tensor_names[0]; });
    202       AddInput<string>(TensorShape({1}), [&](int x) -> string {
    203         return "";
    204       });  // Restores in full.
    205       TF_ASSERT_OK(RunOpKernel());
    206       Tensor* output = GetOutput(0);
    207       TensorShape expected({2});
    208       EXPECT_TRUE(output->shape().IsSameSize(expected));
    209       for (int i = 0; i < 2; ++i) {
    210         EXPECT_EQ(i != 0, output->flat<bool>()(i));
    211       }
    212     }
    213     // The 1-d integer tensor
    214     {
    215       MakeRestoreOp(DT_INT32);
    216       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[1];
    217       TF_ASSERT_OK(RunOpKernel());
    218       Tensor* output = GetOutput(0);
    219       TensorShape expected({10});
    220       EXPECT_TRUE(output->shape().IsSameSize(expected));
    221       for (int i = 0; i < 10; ++i) {
    222         EXPECT_EQ(i + 1, output->flat<int32>()(i));
    223       }
    224     }
    225     // The 2-d float tensor
    226     {
    227       MakeRestoreOp(DT_FLOAT);
    228       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[2];
    229       TF_ASSERT_OK(RunOpKernel());
    230       Tensor* output = GetOutput(0);
    231       TensorShape expected({2, 4});
    232       EXPECT_TRUE(output->shape().IsSameSize(expected));
    233       for (int i = 0; i < 8; ++i) {
    234         EXPECT_EQ(static_cast<float>(i) / 10, output->flat<float>()(i));
    235       }
    236     }
    237     // The 2-d double tensor
    238     {
    239       MakeRestoreOp(DT_DOUBLE);
    240       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[3];
    241       TF_ASSERT_OK(RunOpKernel());
    242       Tensor* output = GetOutput(0);
    243       TensorShape expected({2, 4});
    244       EXPECT_TRUE(output->shape().IsSameSize(expected));
    245       for (int i = 0; i < 8; ++i) {
    246         EXPECT_EQ(static_cast<double>(i) / 20, output->flat<double>()(i));
    247       }
    248     }
    249     // The 2-d qint8 tensor
    250     {
    251       MakeRestoreOp(DT_QINT8);
    252       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[4];
    253       TF_ASSERT_OK(RunOpKernel());
    254       Tensor* output = GetOutput(0);
    255       TensorShape expected({3, 2});
    256       EXPECT_TRUE(output->shape().IsSameSize(expected));
    257       for (int i = 0; i < 6; ++i) {
    258         EXPECT_EQ(*reinterpret_cast<qint8*>(&i), output->flat<qint8>()(i));
    259       }
    260     }
    261     // The 2-d qint32 tensor
    262     {
    263       MakeRestoreOp(DT_QINT32);
    264       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[5];
    265       TF_ASSERT_OK(RunOpKernel());
    266       Tensor* output = GetOutput(0);
    267       TensorShape expected({2, 3});
    268       EXPECT_TRUE(output->shape().IsSameSize(expected));
    269       for (int i = 0; i < 6; ++i) {
    270         EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2),
    271                   output->flat<qint32>()(i));
    272       }
    273     }
    274     // The 1-d uint8 tensor
    275     {
    276       MakeRestoreOp(DT_UINT8);
    277       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[6];
    278       TF_ASSERT_OK(RunOpKernel());
    279       Tensor* output = GetOutput(0);
    280       TensorShape expected({11});
    281       EXPECT_TRUE(output->shape().IsSameSize(expected));
    282       for (int i = 0; i < 11; ++i) {
    283         EXPECT_EQ(i + 1, output->flat<uint8>()(i));
    284       }
    285     }
    286     // The 1-d int8 tensor
    287     {
    288       MakeRestoreOp(DT_INT8);
    289       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[7];
    290       TF_ASSERT_OK(RunOpKernel());
    291       Tensor* output = GetOutput(0);
    292       TensorShape expected({7});
    293       EXPECT_TRUE(output->shape().IsSameSize(expected));
    294       for (int i = 0; i < 7; ++i) {
    295         EXPECT_EQ(i - 7, output->flat<int8>()(i));
    296       }
    297     }
    298     // The 1-d int16 tensor
    299     {
    300       MakeRestoreOp(DT_INT16);
    301       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[8];
    302       TF_ASSERT_OK(RunOpKernel());
    303       Tensor* output = GetOutput(0);
    304       TensorShape expected({7});
    305       EXPECT_TRUE(output->shape().IsSameSize(expected));
    306       for (int i = 0; i < 7; ++i) {
    307         EXPECT_EQ(i - 8, output->flat<int16>()(i));
    308       }
    309     }
    310     // The 1-d int64 tensor
    311     {
    312       MakeRestoreOp(DT_INT64);
    313       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[9];
    314       TF_ASSERT_OK(RunOpKernel());
    315       Tensor* output = GetOutput(0);
    316       TensorShape expected({9});
    317       EXPECT_TRUE(output->shape().IsSameSize(expected));
    318       for (int i = 0; i < 9; ++i) {
    319         EXPECT_EQ(i - 9, output->flat<int64>()(i));
    320       }
    321     }
    322     // The 2-d complex64 tensor
    323     {
    324       MakeRestoreOp(DT_COMPLEX64);
    325       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[10];
    326       TF_ASSERT_OK(RunOpKernel());
    327       Tensor* output = GetOutput(0);
    328       TensorShape expected({2, 3});
    329       EXPECT_TRUE(output->shape().IsSameSize(expected));
    330       for (int i = 0; i < 6; ++i) {
    331         EXPECT_EQ(complex64(100 + i, 200 + i), output->flat<complex64>()(i));
    332       }
    333     }
    334     // The 2-d half tensor
    335     {
    336       MakeRestoreOp(DT_HALF);
    337       (*mutable_input(1).tensor).flat<string>()(0) = tensor_names[11];
    338       TF_ASSERT_OK(RunOpKernel());
    339       Tensor* output = GetOutput(0);
    340       TensorShape expected({2, 4});
    341       EXPECT_TRUE(output->shape().IsSameSize(expected));
    342       for (int i = 0; i < 8; ++i) {
    343         EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(5),
    344                   output->flat<Eigen::half>()(i));
    345       }
    346     }
    347   }
    348 };
    349 
    350 // The intended use case (write in V2, read in V2).
    351 TEST_F(RestoreV2OpTest, RestoreAfterSaveV2) { RunTest("SaveV2"); }
    352 // For backward compatibility.
    353 TEST_F(RestoreV2OpTest, RestoreAfterSaveSlicesV1) { RunTest("SaveSlices"); }
    354 TEST_F(RestoreV2OpTest, RestoreAfterSaveV1) { RunTest("Save"); }
    355 
    356 }  // namespace
    357 }  // namespace tensorflow
    358