// Code-browser navigation header (non-code): Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <functional>
     17 #include <memory>
     18 #include <vector>
     19 
     20 #include "tensorflow/core/common_runtime/device.h"
     21 #include "tensorflow/core/common_runtime/device_factory.h"
     22 #include "tensorflow/core/framework/allocator.h"
     23 #include "tensorflow/core/framework/fake_input.h"
     24 #include "tensorflow/core/framework/node_def_builder.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/types.h"
     28 #include "tensorflow/core/framework/types.pb.h"
     29 #include "tensorflow/core/kernels/ops_testutil.h"
     30 #include "tensorflow/core/kernels/ops_util.h"
     31 #include "tensorflow/core/lib/io/path.h"
     32 #include "tensorflow/core/lib/strings/strcat.h"
     33 #include "tensorflow/core/platform/test.h"
     34 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
     35 
     36 namespace tensorflow {
     37 namespace {
     38 
     39 class RestoreOpTest : public OpsTestBase {
     40  protected:
     41   // Makes an operation to restore two tensors
     42   void MakeRestoreOp(DataType dt) {
     43     TF_ASSERT_OK(NodeDefBuilder("myop", "Restore")
     44                      .Input(FakeInput())
     45                      .Input(FakeInput())
     46                      .Attr("dt", dt)
     47                      .Finalize(node_def()));
     48     TF_ASSERT_OK(InitOp());
     49   }
     50 };
     51 
     52 // Make an input tensor with filled results.
     53 template <typename T>
     54 Tensor MakeInput(const TensorShape& shape,
     55                  std::function<T(int)> input_mapping) {
     56   Tensor input(DataTypeToEnum<T>::v(), shape);
     57   test::FillFn(&input, input_mapping);
     58   return input;
     59 }
     60 
     61 TEST_F(RestoreOpTest, RestoreSimple) {
     62   const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
     63   const std::vector<string> tensor_names = {
     64       "tensor_bool",  "tensor_int",        "tensor_float",  "tensor_double",
     65       "tensor_qint8", "tensor_qint32",     "tensor_uint8",  "tensor_int8",
     66       "tensor_int16", "tensor_int64",      "tensor_string", "tensor_complex64",
     67       "tensor_half",  "tensor_float_empty"};
     68 
     69   // We first need to write a tensor using the save_op
     70   {
     71     // Initialize an operation
     72     NodeDef save;
     73     TF_ASSERT_OK(
     74         NodeDefBuilder("myop", "Save")
     75             .Input(FakeInput())
     76             .Input(FakeInput())
     77             .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
     78                               DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_STRING,
     79                               DT_COMPLEX64, DT_HALF}))
     80             .Finalize(&save));
     81 
     82     std::unique_ptr<Device> device(
     83         DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
     84 
     85     gtl::InlinedVector<TensorValue, 4> inputs;
     86 
     87     Status status;
     88     std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
     89                                                 cpu_allocator(), save,
     90                                                 TF_GRAPH_DEF_VERSION, &status));
     91     TF_EXPECT_OK(status);
     92 
     93     // Run it
     94 
     95     // Input #0 is the file name
     96     Tensor input_0(DT_STRING, TensorShape({}));
     97     input_0.scalar<string>()() = filename;
     98     inputs.push_back({nullptr, &input_0});
     99 
    100     // Input #1 is the tensor names
    101     Tensor input_1 = MakeInput<string>(
    102         TensorShape({static_cast<int>(tensor_names.size())}),
    103         [&tensor_names](int x) -> string { return tensor_names[x]; });
    104     inputs.push_back({nullptr, &input_1});
    105 
    106     // Input #2 is a 1-d bool tensor
    107     Tensor input_2 =
    108         MakeInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; });
    109     inputs.push_back({nullptr, &input_2});
    110     // Input #3 is a 1-d integer tensor
    111     Tensor input_3 = MakeInput<int32>(TensorShape({10}),
    112                                       [](int x) -> int32 { return x + 1; });
    113     inputs.push_back({nullptr, &input_3});
    114     // Input #4 is a 2-d float tensor
    115     Tensor input_4 = MakeInput<float>(TensorShape({2, 4}), [](int x) -> float {
    116       return static_cast<float>(x) / 10;
    117     });
    118     inputs.push_back({nullptr, &input_4});
    119     // Input #5 is a 2-d double tensor
    120     Tensor input_5 = MakeInput<double>(
    121         TensorShape({2, 4}),
    122         [](int x) -> double { return static_cast<double>(x) / 20; });
    123     inputs.push_back({nullptr, &input_5});
    124     // Input #6 is a 2-d qint8 tensor
    125     Tensor input_6 = MakeInput<qint8>(TensorShape({3, 2}), [](int x) -> qint8 {
    126       return *reinterpret_cast<qint8*>(&x);
    127     });
    128     inputs.push_back({nullptr, &input_6});
    129     // Input #7 is a 2-d qint32 tensor
    130     Tensor input_7 =
    131         MakeInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
    132           return *reinterpret_cast<qint32*>(&x) * qint8(2);
    133         });
    134     inputs.push_back({nullptr, &input_7});
    135     // Input #8 is a 1-d uint8 tensor
    136     Tensor input_8 = MakeInput<uint8>(TensorShape({11}),
    137                                       [](int x) -> uint8 { return x + 1; });
    138     inputs.push_back({nullptr, &input_8});
    139     // Input #9 is a 1-d int8 tensor
    140     Tensor input_9 =
    141         MakeInput<int8>(TensorShape({7}), [](int x) -> int8 { return x - 7; });
    142     inputs.push_back({nullptr, &input_9});
    143     // Input #10 is a 1-d int16 tensor
    144     Tensor input_10 = MakeInput<int16>(TensorShape({7}),
    145                                        [](int x) -> int16 { return x - 8; });
    146     inputs.push_back({nullptr, &input_10});
    147     // Input #11 is a 1-d int64 tensor
    148     Tensor input_11 = MakeInput<int64>(TensorShape({9}),
    149                                        [](int x) -> int64 { return x - 9; });
    150     inputs.push_back({nullptr, &input_11});
    151     // Input #12 is a 1-d string tensor
    152     Tensor input_12 = MakeInput<string>(
    153         TensorShape({2}), [](int x) -> string { return x ? "yes" : "no"; });
    154     inputs.push_back({nullptr, &input_12});
    155     // Input #13 is a 1-d complex64 tensor
    156     Tensor input_13 = MakeInput<complex64>(
    157         TensorShape({2, 3}),
    158         [](int x) -> complex64 { return complex64(100 + x, 200 + x); });
    159     inputs.push_back({nullptr, &input_13});
    160     // Input #14 is a 2-d half tensor
    161     Tensor input_14 =
    162         MakeInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
    163           return static_cast<Eigen::half>(x) / Eigen::half(5);
    164         });
    165     inputs.push_back({nullptr, &input_14});
    166     // Input #15 is a 2-d empty float tensor
    167     Tensor input_15 = MakeInput<float>(TensorShape({2, 0}), [](int x) -> float {
    168       return static_cast<float>(x) / 10;
    169     });
    170     inputs.push_back({nullptr, &input_15});
    171     OpKernelContext::Params params;
    172     params.device = device.get();
    173     params.frame_iter = FrameAndIter(0, 0);
    174     params.inputs = &inputs;
    175     params.op_kernel = op.get();
    176     std::vector<AllocatorAttributes> attrs;
    177     test::SetOutputAttrs(&params, &attrs);
    178     checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
    179     params.slice_reader_cache = &slice_reader_cache_wrapper;
    180 
    181     OpKernelContext ctx(&params);
    182     op->Compute(&ctx);
    183     TF_EXPECT_OK(ctx.status());
    184   }
    185 
    186   // Now we restore
    187 
    188   // The 1-d bool tensor
    189   {
    190     MakeRestoreOp(DT_BOOL);
    191     AddInput<string>(TensorShape({}),
    192                      [&filename](int x) -> string { return filename; });
    193     AddInput<string>(TensorShape({}),
    194                      [&](int x) -> string { return tensor_names[0]; });
    195     TF_ASSERT_OK(RunOpKernel());
    196     Tensor* output = GetOutput(0);
    197     TensorShape expected({2});
    198     EXPECT_TRUE(output->shape().IsSameSize(expected));
    199     for (int i = 0; i < 2; ++i) {
    200       EXPECT_EQ(i != 0, output->flat<bool>()(i));
    201     }
    202   }
    203   // The 1-d integer tensor
    204   {
    205     MakeRestoreOp(DT_INT32);
    206     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[1];
    207     TF_ASSERT_OK(RunOpKernel());
    208     Tensor* output = GetOutput(0);
    209     TensorShape expected({10});
    210     EXPECT_TRUE(output->shape().IsSameSize(expected));
    211     for (int i = 0; i < 10; ++i) {
    212       EXPECT_EQ(i + 1, output->flat<int32>()(i));
    213     }
    214   }
    215   // The 2-d float tensor
    216   {
    217     MakeRestoreOp(DT_FLOAT);
    218     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[2];
    219     TF_ASSERT_OK(RunOpKernel());
    220     Tensor* output = GetOutput(0);
    221     TensorShape expected({2, 4});
    222     EXPECT_TRUE(output->shape().IsSameSize(expected));
    223     for (int i = 0; i < 8; ++i) {
    224       EXPECT_EQ(static_cast<float>(i) / 10, output->flat<float>()(i));
    225     }
    226   }
    227   // The 2-d double tensor
    228   {
    229     MakeRestoreOp(DT_DOUBLE);
    230     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[3];
    231     TF_ASSERT_OK(RunOpKernel());
    232     Tensor* output = GetOutput(0);
    233     TensorShape expected({2, 4});
    234     EXPECT_TRUE(output->shape().IsSameSize(expected));
    235     for (int i = 0; i < 8; ++i) {
    236       EXPECT_EQ(static_cast<double>(i) / 20, output->flat<double>()(i));
    237     }
    238   }
    239   // The 2-d qint8 tensor
    240   {
    241     MakeRestoreOp(DT_QINT8);
    242     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[4];
    243     TF_ASSERT_OK(RunOpKernel());
    244     Tensor* output = GetOutput(0);
    245     TensorShape expected({3, 2});
    246     EXPECT_TRUE(output->shape().IsSameSize(expected));
    247     for (int i = 0; i < 6; ++i) {
    248       EXPECT_EQ(*reinterpret_cast<qint8*>(&i), output->flat<qint8>()(i));
    249     }
    250   }
    251   // The 2-d qint32 tensor
    252   {
    253     MakeRestoreOp(DT_QINT32);
    254     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[5];
    255     TF_ASSERT_OK(RunOpKernel());
    256     Tensor* output = GetOutput(0);
    257     TensorShape expected({2, 3});
    258     EXPECT_TRUE(output->shape().IsSameSize(expected));
    259     for (int i = 0; i < 6; ++i) {
    260       EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2),
    261                 output->flat<qint32>()(i));
    262     }
    263   }
    264   // The 1-d uint8 tensor
    265   {
    266     MakeRestoreOp(DT_UINT8);
    267     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[6];
    268     TF_ASSERT_OK(RunOpKernel());
    269     Tensor* output = GetOutput(0);
    270     TensorShape expected({11});
    271     EXPECT_TRUE(output->shape().IsSameSize(expected));
    272     for (int i = 0; i < 11; ++i) {
    273       EXPECT_EQ(i + 1, output->flat<uint8>()(i));
    274     }
    275   }
    276   // The 1-d int8 tensor
    277   {
    278     MakeRestoreOp(DT_INT8);
    279     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[7];
    280     TF_ASSERT_OK(RunOpKernel());
    281     Tensor* output = GetOutput(0);
    282     TensorShape expected({7});
    283     EXPECT_TRUE(output->shape().IsSameSize(expected));
    284     for (int i = 0; i < 7; ++i) {
    285       EXPECT_EQ(i - 7, output->flat<int8>()(i));
    286     }
    287   }
    288   // The 1-d int16 tensor
    289   {
    290     MakeRestoreOp(DT_INT16);
    291     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[8];
    292     TF_ASSERT_OK(RunOpKernel());
    293     Tensor* output = GetOutput(0);
    294     TensorShape expected({7});
    295     EXPECT_TRUE(output->shape().IsSameSize(expected));
    296     for (int i = 0; i < 7; ++i) {
    297       EXPECT_EQ(i - 8, output->flat<int16>()(i));
    298     }
    299   }
    300   // The 1-d int64 tensor
    301   {
    302     MakeRestoreOp(DT_INT64);
    303     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[9];
    304     TF_ASSERT_OK(RunOpKernel());
    305     Tensor* output = GetOutput(0);
    306     TensorShape expected({9});
    307     EXPECT_TRUE(output->shape().IsSameSize(expected));
    308     for (int i = 0; i < 9; ++i) {
    309       EXPECT_EQ(i - 9, output->flat<int64>()(i));
    310     }
    311   }
    312   // The 1-d string tensor
    313   {
    314     MakeRestoreOp(DT_STRING);
    315     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[10];
    316     TF_ASSERT_OK(RunOpKernel());
    317     Tensor* output = GetOutput(0);
    318     TensorShape expected({2});
    319     EXPECT_TRUE(output->shape().IsSameSize(expected));
    320     EXPECT_EQ("no", output->flat<string>()(0));
    321     EXPECT_EQ("yes", output->flat<string>()(1));
    322   }
    323   // The 2-d complex64 tensor
    324   {
    325     MakeRestoreOp(DT_COMPLEX64);
    326     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[11];
    327     TF_ASSERT_OK(RunOpKernel());
    328     Tensor* output = GetOutput(0);
    329     TensorShape expected({2, 3});
    330     EXPECT_TRUE(output->shape().IsSameSize(expected));
    331     for (int i = 0; i < 6; ++i) {
    332       EXPECT_EQ(complex64(100 + i, 200 + i), output->flat<complex64>()(i));
    333     }
    334   }
    335   // The 2-d half tensor
    336   {
    337     MakeRestoreOp(DT_HALF);
    338     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[12];
    339     TF_ASSERT_OK(RunOpKernel());
    340     Tensor* output = GetOutput(0);
    341     TensorShape expected({2, 4});
    342     EXPECT_TRUE(output->shape().IsSameSize(expected));
    343     for (int i = 0; i < 8; ++i) {
    344       EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(5),
    345                 output->flat<Eigen::half>()(i));
    346     }
    347   }
    348   // The 2-d empty float tensor
    349   {
    350     MakeRestoreOp(DT_FLOAT);
    351     (*mutable_input(1).tensor).scalar<string>()() = tensor_names[13];
    352     TF_ASSERT_OK(RunOpKernel());
    353     Tensor* output = GetOutput(0);
    354     TensorShape expected({2, 0});
    355     EXPECT_TRUE(output->shape().IsSameSize(expected));
    356   }
    357 }
    358 
    359 class RestoreSliceOpTest : public OpsTestBase {
    360  protected:
    361   void MakeRestoreSliceOp(DataType dt) {
    362     TF_ASSERT_OK(NodeDefBuilder("myop", "RestoreSlice")
    363                      .Input(FakeInput())
    364                      .Input(FakeInput())
    365                      .Input(FakeInput())
    366                      .Attr("dt", dt)
    367                      .Finalize(node_def()));
    368     TF_ASSERT_OK(InitOp());
    369   }
    370 };
    371 
// Saves a 4x16 int32 tensor, then restores only the slice [0:2, :] of it
// with RestoreSlice and verifies the restored 2x16 values.
TEST_F(RestoreSliceOpTest, RestoreInt) {
  const string filename = io::JoinPath(testing::TmpDir(), "tensor_int");
  const string tensor_name = "tensor_int";

  // We first need to write a tensor using the save_op
  {
    // Initialize an operation
    NodeDef save;
    TF_ASSERT_OK(NodeDefBuilder("save", "Save")
                     .Input(FakeInput(DT_STRING))
                     .Input(FakeInput(DT_STRING))
                     .Input(FakeInput({DT_INT32}))
                     .Finalize(&save));

    std::unique_ptr<Device> device(
        DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));

    gtl::InlinedVector<TensorValue, 4> inputs;

    Status status;
    std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
                                                cpu_allocator(), save,
                                                TF_GRAPH_DEF_VERSION, &status));
    TF_EXPECT_OK(status);

    // Run it

    // Input #0 is the file name
    Tensor input_0(DT_STRING, TensorShape({}));
    input_0.scalar<string>()() = filename;
    inputs.push_back({nullptr, &input_0});

    // Input #1 is the tensor name
    Tensor input_1(DT_STRING, TensorShape({}));
    input_1.scalar<string>()() = tensor_name;
    inputs.push_back({nullptr, &input_1});

    // Input #2 is a 4x16 integer tensor, filled row-major with 1..64.
    Tensor input_2(DT_INT32, TensorShape({4, 16}));
    for (int64 i = 0; i < input_2.NumElements(); ++i) {
      input_2.flat<int32>()(i) = i + 1;
    }
    inputs.push_back({nullptr, &input_2});

    // Assemble a minimal OpKernelContext by hand (no Session involved) and
    // run the Save kernel once to write the checkpoint file.
    OpKernelContext::Params params;
    params.device = device.get();
    params.frame_iter = FrameAndIter(0, 0);
    params.inputs = &inputs;
    params.op_kernel = op.get();
    std::vector<AllocatorAttributes> attrs;
    test::SetOutputAttrs(&params, &attrs);
    checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
    params.slice_reader_cache = &slice_reader_cache_wrapper;

    OpKernelContext ctx(&params);
    op->Compute(&ctx);
    TF_EXPECT_OK(ctx.status());
  }

  // Now we restore
  MakeRestoreSliceOp(DT_INT32);
  // Full shape is "4 16"; the slice spec "0,2:-" selects rows [0, 2) and
  // all columns, i.e. the first two rows (saved elements 1..32).
  string shape_and_slice = "4 16 0,2:-";
  // Add a file name
  AddInput<string>(TensorShape({}),
                   [&filename](int x) -> string { return filename; });
  // Add the tensor names
  AddInput<string>(TensorShape({}),
                   [&tensor_name](int x) -> string { return tensor_name; });
  // Add the tensor shape and slice
  AddInput<string>(TensorShape({}), [&shape_and_slice](int x) -> string {
    return shape_and_slice;
  });

  TF_ASSERT_OK(RunOpKernel());

  // Check that we have an integer tensor
  Tensor* output = GetOutput(0);
  TensorShape expected({2, 16});
  EXPECT_TRUE(output->shape().IsSameSize(expected));
  for (int64 i = 0; i < expected.num_elements(); ++i) {
    EXPECT_EQ(i + 1, output->flat<int32>()(i));
  }
}
    455 
    456 }  // namespace
    457 }  // namespace tensorflow
    458