/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "tensorflow/core/framework/variant_op_registry.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

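// VariantValue is a minimal payload type used to exercise the Variant op
// registry. Its static members below act as the shape, zeros_like, add, and
// device-copy callbacks registered for it; `early_exit` triggers the error
// path in each callback, and the sentinel values written to `value` let the
// tests tell the CPU and GPU code paths apart.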
struct VariantValue {
  string TypeName() const { return "TEST VariantValue"; }
  static Status ShapeFn(const VariantValue& v, TensorShape* s) {
    if (v.early_exit) {
      return errors::InvalidArgument("early exit!");
    }
    *s = TensorShape({-0xdeadbeef});
    return Status::OK();
  }
  static Status CPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
                               VariantValue* v_out) {
    if (v.early_exit) {
      return errors::InvalidArgument("early exit zeros_like!");
    }
    v_out->value = 1;  // CPU
    return Status::OK();
  }
  static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
                               VariantValue* v_out) {
    if (v.early_exit) {
      return errors::InvalidArgument("early exit zeros_like!");
    }
    v_out->value = 2;  // GPU
    return Status::OK();
  }
  static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a,
                         const VariantValue& b, VariantValue* out) {
    if (a.early_exit) {
      return errors::InvalidArgument("early exit add!");
    }
    out->value = a.value + b.value;  // CPU
    return Status::OK();
  }
  static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a,
                         const VariantValue& b, VariantValue* out) {
    if (a.early_exit) {
      return errors::InvalidArgument("early exit add!");
    }
    out->value = -(a.value + b.value);  // GPU
    return Status::OK();
  }
  static Status CPUToGPUCopyFn(
      const VariantValue& from, VariantValue* to,
      const std::function<Status(const Tensor&, Tensor*)>& copier) {
    TF_RETURN_IF_ERROR(copier(Tensor(), nullptr));
    to->value = 0xdeadbeef;
    return Status::OK();
  }
  bool early_exit;
  int value;
};

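// Register the callbacks above under the type-name string "TEST VariantValue".
// Lookups go through this string (plus op enum and device name where
// applicable), so queries for an unregistered name should return nullptr.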
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
                                      VariantValue::ShapeFn);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");

INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
    VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
    "TEST VariantValue", VariantValue::CPUToGPUCopyFn);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, VariantValue,
                                         "TEST VariantValue",
                                         VariantValue::CPUZerosLikeFn);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_GPU, VariantValue,
                                         "TEST VariantValue",
                                         VariantValue::GPUZerosLikeFn);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          VariantValue, "TEST VariantValue",
                                          VariantValue::CPUAddFn);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
                                          VariantValue, "TEST VariantValue",
                                          VariantValue::GPUAddFn);

}  // namespace

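// Looks up the registered shape function by type name, then checks both the
// error path (early_exit set) and the success path, which reports the
// sentinel shape written by VariantValue::ShapeFn.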
TEST(VariantOpShapeRegistryTest, TestBasic) {
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
            nullptr);

  auto* shape_fn =
      UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
  EXPECT_NE(shape_fn, nullptr);
  TensorShape shape;

  VariantValue vv_early_exit{true /* early_exit */};
  Variant v = vv_early_exit;
  Status s0 = (*shape_fn)(v, &shape);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit!"));

  VariantValue vv_ok{false /* early_exit */};
  v = vv_ok;
  TF_EXPECT_OK((*shape_fn)(v, &shape));
  EXPECT_EQ(shape, TensorShape({-0xdeadbeef}));
}

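// Each TestDuplicate case below registers a callback under the same type name
// twice in a fresh registry and expects the second registration to die with
// an "already registered" failure (hence EXPECT_DEATH).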
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::VariantShapeFn f;
  string kTypeName = "fjfjfj";
  registry.RegisterShapeFn(kTypeName, f);
  EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
               "fjfjfj already registered");
}

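// Round-trips a VariantValue through Encode -> VariantTensorDataProto, then
// checks that the registered decode function restores the typed payload.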
TEST(VariantOpDecodeRegistryTest, TestBasic) {
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDecodeFn("YOU SHALL NOT PASS"),
            nullptr);

  auto* decode_fn =
      UnaryVariantOpRegistry::Global()->GetDecodeFn("TEST VariantValue");
  EXPECT_NE(decode_fn, nullptr);

  VariantValue vv{true /* early_exit */};
  Variant v = vv;
  VariantTensorData data;
  v.Encode(&data);
  VariantTensorDataProto proto;
  data.ToProto(&proto);
  Variant encoded = proto;
  EXPECT_TRUE((*decode_fn)(&encoded));
  VariantValue* decoded = encoded.get<VariantValue>();
  EXPECT_NE(decoded, nullptr);
  EXPECT_EQ(decoded->early_exit, true);
}

TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::VariantDecodeFn f;
  string kTypeName = "fjfjfj";
  registry.RegisterDecodeFn(kTypeName, f);
  EXPECT_DEATH(registry.RegisterDecodeFn(kTypeName, f),
               "fjfjfj already registered");
}

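// Device-copy functions are looked up by (direction, type name); only
// HOST_TO_DEVICE was registered above. The dummy copier confirms that the
// registered function invokes the tensor-copy callback it is handed.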
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
  // No registered copy fn for GPU<->GPU.
  EXPECT_EQ(
      UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
          VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
      nullptr);

  auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
      VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
  EXPECT_NE(copy_to_gpu_fn, nullptr);

  VariantValue vv{true /* early_exit */};
  Variant v = vv;
  Variant v_out;
  bool dummy_executed = false;
  auto dummy_copy_fn = [&dummy_executed](const Tensor& from,
                                         Tensor* to) -> Status {
    dummy_executed = true;
    return Status::OK();
  };
  TF_EXPECT_OK((*copy_to_gpu_fn)(v, &v_out, dummy_copy_fn));
  EXPECT_TRUE(dummy_executed);
  VariantValue* copied_value = v_out.get<VariantValue>();
  EXPECT_NE(copied_value, nullptr);
  EXPECT_EQ(copied_value->value, 0xdeadbeef);
}

TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
  string kTypeName = "fjfjfj";
  registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
                                kTypeName, f);
  EXPECT_DEATH(registry.RegisterDeviceCopyFn(
                   VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
               "fjfjfj already registered");
}

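// Dispatches ZEROS_LIKE through UnaryOpVariant; the sentinel results
// (1 for CPU, 2 for GPU) identify which registered callback ran.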
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
                ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
            nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
  Variant v = vv_early_exit;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = UnaryOpVariant<CPUDevice>(null_context_pointer,
                                        ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(
      StringPiece(s0.error_message()).contains("early exit zeros_like"));

  VariantValue vv_ok{false /* early_exit */, 0 /* value */};
  v = vv_ok;
  TF_EXPECT_OK(UnaryOpVariant<CPUDevice>(
      null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, 1);  // CPU
}

#if GOOGLE_CUDA
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
                ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
            nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
  Variant v = vv_early_exit;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = UnaryOpVariant<GPUDevice>(null_context_pointer,
                                        ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(
      StringPiece(s0.error_message()).contains("early exit zeros_like"));

  VariantValue vv_ok{false /* early_exit */, 0 /* value */};
  v = vv_ok;
  TF_EXPECT_OK(UnaryOpVariant<GPUDevice>(
      null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, 2);  // GPU
}
#endif  // GOOGLE_CUDA

TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::VariantUnaryOpFn f;
  string kTypeName = "fjfjfj";

  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
                             f);
  EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_CPU, kTypeName, f),
               "fjfjfj already registered");

  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
                             f);
  EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_GPU, kTypeName, f),
               "fjfjfj already registered");
}

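// Dispatches ADD through BinaryOpVariants. The CPU callback returns the sum
// (3 + 4 = 7) while the GPU callback negates it (-7), so the sign of the
// result identifies which registration was used.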
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
                ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
            nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
  VariantValue vv_other{true /* early_exit */, 4 /* value */};
  Variant v_a = vv_early_exit;
  Variant v_b = vv_other;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = BinaryOpVariants<CPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));

  VariantValue vv_ok{false /* early_exit */, 3 /* value */};
  v_a = vv_ok;
  TF_EXPECT_OK(BinaryOpVariants<CPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, 7);  // CPU
}

#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
                ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
            nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
  VariantValue vv_other{true /* early_exit */, 4 /* value */};
  Variant v_a = vv_early_exit;
  Variant v_b = vv_other;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = BinaryOpVariants<GPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));

  VariantValue vv_ok{false /* early_exit */, 3 /* value */};
  v_a = vv_ok;
  TF_EXPECT_OK(BinaryOpVariants<GPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, -7);  // GPU
}
#endif  // GOOGLE_CUDA

TEST(VariantOpAddRegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::VariantBinaryOpFn f;
  string kTypeName = "fjfjfj";

  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
  EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                           kTypeName, f),
               "fjfjfj already registered");

  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
  EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
                                           kTypeName, f),
               "fjfjfj already registered");
}

}  // namespace tensorflow