/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"

namespace tensorflow {
namespace {

class SerializeTensorOpTest : public OpsTestBase {
 protected:
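  // Builds a SerializeTensor node, initializes the kernel, and adds an input
  // tensor of `input_shape` whose elements are produced by `functor`.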
  template <typename T>
  void MakeOp(const TensorShape& input_shape, std::function<T(int)> functor) {
    TF_ASSERT_OK(NodeDefBuilder("myop", "SerializeTensor")
                     .Input(FakeInput(DataTypeToEnum<T>::value))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
    AddInput<T>(input_shape, functor);
  }
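  // Runs the given ParseTensor node on `serialized` with a fresh CPU device
  // and copies the kernel's output into `parse_output`.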
  void ParseSerializedWithNodeDef(const NodeDef& parse_node_def,
                                  Tensor* serialized, Tensor* parse_output) {
    std::unique_ptr<Device> device(
        DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
    gtl::InlinedVector<TensorValue, 4> inputs;
    inputs.push_back({nullptr, serialized});
    Status status;
    std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
                                                cpu_allocator(), parse_node_def,
                                                TF_GRAPH_DEF_VERSION, &status));
    TF_EXPECT_OK(status);
    OpKernelContext::Params params;
    params.device = device.get();
    params.inputs = &inputs;
    params.frame_iter = FrameAndIter(0, 0);
    params.op_kernel = op.get();
    std::vector<AllocatorAttributes> attrs;
    test::SetOutputAttrs(&params, &attrs);
    OpKernelContext ctx(&params);
    op->Compute(&ctx);
    TF_EXPECT_OK(ctx.status());
    *parse_output = *ctx.mutable_output(0);
  }
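  // Round-trips `serialized` through a ParseTensor kernel with output type T.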
  template <typename T>
  void ParseSerializedOutput(Tensor* serialized, Tensor* parse_output) {
    NodeDef parse;
    TF_ASSERT_OK(NodeDefBuilder("parse", "ParseTensor")
                     .Input(FakeInput(DT_STRING))
                     .Attr("out_type", DataTypeToEnum<T>::value)
                     .Finalize(&parse));
    ParseSerializedWithNodeDef(parse, serialized, parse_output);
  }
};

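// Each test below serializes a tensor of one dtype with SerializeTensor and
// checks that ParseTensor reproduces the original input exactly.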
TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) {
  MakeOp<Eigen::half>(TensorShape({10}), [](int x) -> Eigen::half {
    return static_cast<Eigen::half>(x / 10.);
  });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<Eigen::half>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<Eigen::half>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) {
  MakeOp<float>(TensorShape({1, 10}),
                [](int x) -> float { return static_cast<float>(x / 10.); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<float>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<float>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) {
  MakeOp<double>(TensorShape({5, 5}),
                 [](int x) -> double { return static_cast<double>(x / 10.); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<double>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<double>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) {
  MakeOp<int64>(TensorShape({2, 3, 4}),
                [](int x) -> int64 { return static_cast<int64>(x - 10); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<int64>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<int64>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) {
  MakeOp<int32>(TensorShape({4, 2}),
                [](int x) -> int32 { return static_cast<int32>(x + 7); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<int32>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<int32>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) {
  MakeOp<int16>(TensorShape({8}),
                [](int x) -> int16 { return static_cast<int16>(x + 18); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<int16>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<int16>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) {
  MakeOp<int8>(TensorShape({2}),
               [](int x) -> int8 { return static_cast<int8>(x + 8); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<int8>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<int8>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) {
  MakeOp<uint16>(TensorShape({1, 3}),
                 [](int x) -> uint16 { return static_cast<uint16>(x + 2); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<uint16>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<uint16>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) {
  MakeOp<uint8>(TensorShape({2, 1, 1}),
                [](int x) -> uint8 { return static_cast<uint8>(x + 1); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<uint8>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<uint8>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) {
  MakeOp<complex64>(TensorShape({}), [](int x) -> complex64 {
    return complex64{static_cast<float>(x / 8.), static_cast<float>(x / 2.)};
  });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<complex64>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<complex64>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) {
  MakeOp<complex128>(TensorShape({3}), [](int x) -> complex128 {
    return complex128{x / 3., x / 2.};
  });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<complex128>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<complex128>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) {
  MakeOp<bool>(TensorShape({1}),
               [](int x) -> bool { return static_cast<bool>(x % 2); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<bool>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<bool>(parse_output, GetInput(0));
}

TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_string) {
  MakeOp<string>(TensorShape({10}),
                 [](int x) -> string { return std::to_string(x / 10.); });
  TF_ASSERT_OK(RunOpKernel());
  Tensor parse_output;
  ParseSerializedOutput<string>(GetOutput(0), &parse_output);
  test::ExpectTensorEqual<string>(parse_output, GetInput(0));
}

}  // namespace
}  // namespace tensorflow