1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 18 #define EIGEN_USE_THREADS 19 20 #if GOOGLE_CUDA 21 #define EIGEN_USE_GPU 22 #endif 23 24 #include "tensorflow/core/framework/variant_op_registry.h" 25 26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/status_test_util.h" 30 #include "tensorflow/core/platform/test.h" 31 32 namespace tensorflow { 33 34 typedef Eigen::ThreadPoolDevice CPUDevice; 35 typedef Eigen::GpuDevice GPUDevice; 36 37 namespace { 38 39 struct VariantValue { 40 string TypeName() const { return "TEST VariantValue"; } 41 static Status ShapeFn(const VariantValue& v, TensorShape* s) { 42 if (v.early_exit) { 43 return errors::InvalidArgument("early exit!"); 44 } 45 *s = TensorShape({-0xdeadbeef}); 46 return Status::OK(); 47 } 48 static Status CPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, 49 VariantValue* v_out) { 50 if (v.early_exit) { 51 return errors::InvalidArgument("early exit zeros_like!"); 52 } 53 v_out->value = 1; // CPU 54 return Status::OK(); 55 } 56 static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, 57 VariantValue* v_out) { 58 if (v.early_exit) { 59 return errors::InvalidArgument("early exit zeros_like!"); 60 } 61 v_out->value = 2; // GPU 62 return Status::OK(); 63 } 64 static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a, 65 const VariantValue& b, VariantValue* out) { 66 if (a.early_exit) { 67 return errors::InvalidArgument("early exit add!"); 68 } 69 out->value = a.value + b.value; // CPU 70 return Status::OK(); 71 } 72 static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a, 73 const VariantValue& b, VariantValue* out) { 74 if (a.early_exit) { 75 return errors::InvalidArgument("early exit add!"); 76 } 77 out->value = -(a.value + b.value); // GPU 78 return Status::OK(); 79 } 80 static Status CPUToGPUCopyFn( 81 const VariantValue& from, VariantValue* to, 82 const std::function<Status(const Tensor&, Tensor*)>& copier) { 83 TF_RETURN_IF_ERROR(copier(Tensor(), nullptr)); 84 to->value = 0xdeadbeef; 85 return Status::OK(); 86 } 87 bool early_exit; 88 int value; 89 }; 90 91 REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", 92 VariantValue::ShapeFn); 93 94 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); 95 96 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( 97 VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, 98 "TEST VariantValue", VariantValue::CPUToGPUCopyFn); 99 100 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, 101 DEVICE_CPU, VariantValue, 102 "TEST VariantValue", 103 VariantValue::CPUZerosLikeFn); 104 105 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, 106 DEVICE_GPU, VariantValue, 107 "TEST VariantValue", 108 VariantValue::GPUZerosLikeFn); 109 110 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, 111 VariantValue, "TEST VariantValue", 112 VariantValue::CPUAddFn); 113 114 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, 115 VariantValue, "TEST VariantValue", 116 VariantValue::GPUAddFn); 117 118 } // namespace 119 120 TEST(VariantOpShapeRegistryTest, TestBasic) { 121 EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"), 122 nullptr); 123 124 auto* shape_fn = 125 UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue"); 126 EXPECT_NE(shape_fn, nullptr); 127 TensorShape shape; 128 129 VariantValue vv_early_exit{true /* early_exit */}; 130 Variant v = vv_early_exit; 131 Status s0 = (*shape_fn)(v, &shape); 132 EXPECT_FALSE(s0.ok()); 133 EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit!")); 134 135 VariantValue vv_ok{false /* early_exit */}; 136 v = vv_ok; 137 TF_EXPECT_OK((*shape_fn)(v, &shape)); 138 EXPECT_EQ(shape, TensorShape({-0xdeadbeef})); 139 } 140 141 TEST(VariantOpShapeRegistryTest, TestDuplicate) { 142 UnaryVariantOpRegistry registry; 143 UnaryVariantOpRegistry::VariantShapeFn f; 144 string kTypeName = "fjfjfj"; 145 registry.RegisterShapeFn(kTypeName, f); 146 EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), 147 "fjfjfj already registered"); 148 } 149 150 TEST(VariantOpDecodeRegistryTest, TestBasic) { 151 EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDecodeFn("YOU SHALL NOT PASS"), 152 nullptr); 153 154 auto* decode_fn = 155 UnaryVariantOpRegistry::Global()->GetDecodeFn("TEST VariantValue"); 156 EXPECT_NE(decode_fn, nullptr); 157 158 VariantValue vv{true /* early_exit */}; 159 Variant v = vv; 160 VariantTensorData data; 161 v.Encode(&data); 162 VariantTensorDataProto proto; 163 data.ToProto(&proto); 164 Variant encoded = proto; 165 EXPECT_TRUE((*decode_fn)(&encoded)); 166 VariantValue* decoded = encoded.get<VariantValue>(); 167 EXPECT_NE(decoded, nullptr); 168 EXPECT_EQ(decoded->early_exit, true); 169 } 170 171 TEST(VariantOpDecodeRegistryTest, TestDuplicate) { 172 UnaryVariantOpRegistry registry; 173 UnaryVariantOpRegistry::VariantDecodeFn f; 174 string kTypeName = "fjfjfj"; 175 registry.RegisterDecodeFn(kTypeName, f); 176 EXPECT_DEATH(registry.RegisterDecodeFn(kTypeName, f), 177 "fjfjfj already registered"); 178 } 179 180 TEST(VariantOpCopyToGPURegistryTest, TestBasic) { 181 // No registered copy fn for GPU<->GPU. 182 EXPECT_EQ( 183 UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( 184 VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"), 185 nullptr); 186 187 auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( 188 VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue"); 189 EXPECT_NE(copy_to_gpu_fn, nullptr); 190 191 VariantValue vv{true /* early_exit */}; 192 Variant v = vv; 193 Variant v_out; 194 bool dummy_executed = false; 195 auto dummy_copy_fn = [&dummy_executed](const Tensor& from, 196 Tensor* to) -> Status { 197 dummy_executed = true; 198 return Status::OK(); 199 }; 200 TF_EXPECT_OK((*copy_to_gpu_fn)(v, &v_out, dummy_copy_fn)); 201 EXPECT_TRUE(dummy_executed); 202 VariantValue* copied_value = v_out.get<VariantValue>(); 203 EXPECT_NE(copied_value, nullptr); 204 EXPECT_EQ(copied_value->value, 0xdeadbeef); 205 } 206 207 TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) { 208 UnaryVariantOpRegistry registry; 209 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f; 210 string kTypeName = "fjfjfj"; 211 registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE, 212 kTypeName, f); 213 EXPECT_DEATH(registry.RegisterDeviceCopyFn( 214 VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f), 215 "fjfjfj already registered"); 216 } 217 218 TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { 219 EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( 220 ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), 221 nullptr); 222 223 VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; 224 Variant v = vv_early_exit; 225 Variant v_out = VariantValue(); 226 227 OpKernelContext* null_context_pointer = nullptr; 228 Status s0 = UnaryOpVariant<CPUDevice>(null_context_pointer, 229 ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); 230 EXPECT_FALSE(s0.ok()); 231 EXPECT_TRUE( 232 StringPiece(s0.error_message()).contains("early exit zeros_like")); 233 234 VariantValue vv_ok{false /* early_exit */, 0 /* value */}; 235 v = vv_ok; 236 TF_EXPECT_OK(UnaryOpVariant<CPUDevice>( 237 null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out)); 238 VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>()); 239 EXPECT_EQ(vv_out->value, 1); // CPU 240 } 241 242 #if GOOGLE_CUDA 243 TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { 244 EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( 245 ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), 246 nullptr); 247 248 VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; 249 Variant v = vv_early_exit; 250 Variant v_out = VariantValue(); 251 252 OpKernelContext* null_context_pointer = nullptr; 253 Status s0 = UnaryOpVariant<GPUDevice>(null_context_pointer, 254 ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); 255 EXPECT_FALSE(s0.ok()); 256 EXPECT_TRUE( 257 StringPiece(s0.error_message()).contains("early exit zeros_like")); 258 259 VariantValue vv_ok{false /* early_exit */, 0 /* value */}; 260 v = vv_ok; 261 TF_EXPECT_OK(UnaryOpVariant<GPUDevice>( 262 null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out)); 263 VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>()); 264 EXPECT_EQ(vv_out->value, 2); // GPU 265 } 266 #endif // GOOGLE_CUDA 267 268 TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { 269 UnaryVariantOpRegistry registry; 270 UnaryVariantOpRegistry::VariantUnaryOpFn f; 271 string kTypeName = "fjfjfj"; 272 273 registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName, 274 f); 275 EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, 276 DEVICE_CPU, kTypeName, f), 277 "fjfjfj already registered"); 278 279 registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName, 280 f); 281 EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, 282 DEVICE_GPU, kTypeName, f), 283 "fjfjfj already registered"); 284 } 285 286 TEST(VariantOpAddRegistryTest, TestBasicCPU) { 287 return; 288 EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( 289 ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), 290 nullptr); 291 292 VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; 293 VariantValue vv_other{true /* early_exit */, 4 /* value */}; 294 Variant v_a = vv_early_exit; 295 Variant v_b = vv_other; 296 Variant v_out = VariantValue(); 297 298 OpKernelContext* null_context_pointer = nullptr; 299 Status s0 = BinaryOpVariants<CPUDevice>( 300 null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out); 301 EXPECT_FALSE(s0.ok()); 302 EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add")); 303 304 VariantValue vv_ok{false /* early_exit */, 3 /* value */}; 305 v_a = vv_ok; 306 TF_EXPECT_OK(BinaryOpVariants<CPUDevice>( 307 null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out)); 308 VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>()); 309 EXPECT_EQ(vv_out->value, 7); // CPU 310 } 311 312 #if GOOGLE_CUDA 313 TEST(VariantOpAddRegistryTest, TestBasicGPU) { 314 EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( 315 ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), 316 nullptr); 317 318 VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; 319 VariantValue vv_other{true /* early_exit */, 4 /* value */}; 320 Variant v_a = vv_early_exit; 321 Variant v_b = vv_other; 322 Variant v_out = VariantValue(); 323 324 OpKernelContext* null_context_pointer = nullptr; 325 Status s0 = BinaryOpVariants<GPUDevice>( 326 null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out); 327 EXPECT_FALSE(s0.ok()); 328 EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add")); 329 330 VariantValue vv_ok{false /* early_exit */, 3 /* value */}; 331 v_a = vv_ok; 332 TF_EXPECT_OK(BinaryOpVariants<GPUDevice>( 333 null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out)); 334 VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>()); 335 EXPECT_EQ(vv_out->value, -7); // GPU 336 } 337 #endif // GOOGLE_CUDA 338 339 TEST(VariantOpAddRegistryTest, TestDuplicate) { 340 UnaryVariantOpRegistry registry; 341 UnaryVariantOpRegistry::VariantBinaryOpFn f; 342 string kTypeName = "fjfjfj"; 343 344 registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f); 345 EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, 346 kTypeName, f), 347 "fjfjfj already registered"); 348 349 registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f); 350 EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, 351 kTypeName, f), 352 "fjfjfj already registered"); 353 } 354 355 } // namespace tensorflow 356