1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <deque> 17 #include <vector> 18 19 #include "tensorflow/core/common_runtime/device.h" 20 #include "tensorflow/core/common_runtime/executor.h" 21 #include "tensorflow/core/common_runtime/function.h" 22 #include "tensorflow/core/common_runtime/memory_types.h" 23 #include "tensorflow/core/framework/function.h" 24 #include "tensorflow/core/framework/op.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/graph/algorithm.h" 28 #include "tensorflow/core/graph/gradients.h" 29 #include "tensorflow/core/graph/graph_constructor.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/util/device_name_utils.h" 32 33 namespace tensorflow { 34 35 static const char* const kGradientOp = "SymbolicGradient"; 36 37 class ArgOp : public OpKernel { 38 public: 39 explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 40 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 41 OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); 42 } 43 44 void Compute(OpKernelContext* ctx) override { 45 auto frame = ctx->call_frame(); 46 OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); 47 Tensor val; 48 OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); 49 OP_REQUIRES(ctx, val.dtype() == dtype_, 50 errors::InvalidArgument( 51 "Type mismatch: actual ", DataTypeString(val.dtype()), 52 " vs. expect ", DataTypeString(dtype_))); 53 ctx->set_output(0, val); 54 } 55 56 bool IsExpensive() override { return false; } 57 58 private: 59 int index_; 60 DataType dtype_; 61 62 TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); 63 }; 64 65 class RetvalOp : public OpKernel { 66 public: 67 explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 68 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 69 OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); 70 } 71 72 void Compute(OpKernelContext* ctx) override { 73 const Tensor& val = ctx->input(0); 74 OP_REQUIRES(ctx, val.dtype() == dtype_, 75 errors::InvalidArgument( 76 "Type mismatch: actual ", DataTypeString(val.dtype()), 77 " vs. expect ", DataTypeString(dtype_))); 78 auto frame = ctx->call_frame(); 79 OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); 80 OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val)); 81 } 82 83 bool IsExpensive() override { return false; } 84 85 private: 86 int index_; 87 DataType dtype_; 88 89 TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); 90 }; 91 92 REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); 93 REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); 94 95 #if TENSORFLOW_USE_SYCL 96 #define REGISTER(type) \ 97 REGISTER_KERNEL_BUILDER( \ 98 Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp); 99 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) 100 TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") 101 .Device(DEVICE_SYCL) 102 .HostMemory("output") 103 .TypeConstraint<int32>("T"), 104 ArgOp); 105 #undef REGISTER 106 #define REGISTER(type) \ 107 REGISTER_KERNEL_BUILDER( \ 108 Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 109 RetvalOp); 110 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) 111 TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") 112 .Device(DEVICE_SYCL) 113 .HostMemory("input") 114 .TypeConstraint<int32>("T"), 115 RetvalOp); 116 #undef REGISTER 117 #endif 118 119 #define REGISTER(type) \ 120 REGISTER_KERNEL_BUILDER( \ 121 Name("_Arg").Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp); 122 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) 123 TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") 124 .Device(DEVICE_GPU) 125 .HostMemory("output") 126 .TypeConstraint<int32>("T"), 127 ArgOp); 128 #undef REGISTER 129 130 REGISTER_KERNEL_BUILDER(Name("_Arg") 131 .Device(DEVICE_GPU) 132 .HostMemory("output") 133 .TypeConstraint<ResourceHandle>("T"), 134 ArgOp); 135 136 #define REGISTER(type) \ 137 REGISTER_KERNEL_BUILDER( \ 138 Name("_Retval").Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp); 139 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) 140 TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") 141 .Device(DEVICE_GPU) 142 .HostMemory("input") 143 .TypeConstraint<int32>("T"), 144 RetvalOp); 145 #undef REGISTER 146 147 class PassOn : public OpKernel { 148 public: 149 explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) { 150 OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(), 151 errors::Internal("#inputs != #outputs : ", ctx->num_inputs(), 152 " vs. ", ctx->num_outputs())); 153 for (int i = 0; i < ctx->num_inputs(); ++i) { 154 OP_REQUIRES( 155 ctx, input_type(i) == output_type(i), 156 errors::Internal("Input and output types for position ", i, 157 " do not match: ", DataTypeString(input_type(i)), 158 " vs. ", DataTypeString(output_type(i)))); 159 } 160 } 161 162 void Compute(OpKernelContext* ctx) override { 163 for (int i = 0; i < ctx->num_inputs(); ++i) { 164 ctx->set_output(i, ctx->input(i)); 165 } 166 } 167 }; 168 169 REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn); 170 REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn); 171 172 #define REGISTER_GPU_KERNELS(type) \ 173 REGISTER_KERNEL_BUILDER( \ 174 Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 175 PassOn); \ 176 REGISTER_KERNEL_BUILDER( \ 177 Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 178 PassOn); 179 180 REGISTER_GPU_KERNELS(Eigen::half); 181 REGISTER_GPU_KERNELS(float); 182 REGISTER_GPU_KERNELS(double); 183 184 #undef REGISTER_GPU_KERNELS 185 186 REGISTER_KERNEL_BUILDER(Name("_ListToArray") 187 .Device(DEVICE_GPU) 188 .HostMemory("input") 189 .HostMemory("output") 190 .TypeConstraint<int32>("T"), 191 PassOn); 192 REGISTER_KERNEL_BUILDER(Name("_ArrayToList") 193 .Device(DEVICE_GPU) 194 .HostMemory("input") 195 .HostMemory("output") 196 .TypeConstraint<int32>("T"), 197 PassOn); 198 199 #ifdef TENSORFLOW_USE_SYCL 200 #define REGISTER_SYCL_KERNELS(type) \ 201 REGISTER_KERNEL_BUILDER( \ 202 Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 203 PassOn); \ 204 REGISTER_KERNEL_BUILDER( \ 205 Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 206 PassOn); 207 208 REGISTER_SYCL_KERNELS(float); 209 REGISTER_SYCL_KERNELS(double); 210 211 #undef REGISTER_SYCL_KERNELS 212 213 REGISTER_KERNEL_BUILDER(Name("_ListToArray") 214 .Device(DEVICE_SYCL) 215 .HostMemory("input") 216 .HostMemory("output") 217 .TypeConstraint<int32>("T"), 218 PassOn); 219 REGISTER_KERNEL_BUILDER(Name("_ArrayToList") 220 .Device(DEVICE_SYCL) 221 .HostMemory("input") 222 .HostMemory("output") 223 .TypeConstraint<int32>("T"), 224 PassOn); 225 #endif // TENSORFLOW_USE_SYCL 226 227 class SymbolicGradientOp : public AsyncOpKernel { 228 public: 229 explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {} 230 231 ~SymbolicGradientOp() override {} 232 233 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 234 FunctionLibraryRuntime* lib = ctx->function_library(); 235 OP_REQUIRES_ASYNC(ctx, lib != nullptr, 236 errors::Internal("No function library is provided."), 237 done); 238 239 FunctionLibraryRuntime::Handle handle; 240 OP_REQUIRES_OK_ASYNC( 241 ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done); 242 243 FunctionLibraryRuntime::Options opts; 244 opts.step_id = ctx->step_id(); 245 opts.rendezvous = ctx->rendezvous(); 246 opts.cancellation_manager = ctx->cancellation_manager(); 247 opts.runner = ctx->runner(); 248 opts.stats_collector = ctx->stats_collector(); 249 opts.step_container = ctx->step_container(); 250 std::vector<Tensor> args; 251 args.reserve(ctx->num_inputs()); 252 for (int i = 0; i < ctx->num_inputs(); ++i) { 253 args.push_back(ctx->input(i)); 254 } 255 std::vector<Tensor>* rets = new std::vector<Tensor>; 256 lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) { 257 if (!status.ok()) { 258 ctx->SetStatus(status); 259 } else if (rets->size() != ctx->num_outputs()) { 260 ctx->SetStatus(errors::InvalidArgument( 261 "SymGrad expects to return ", ctx->num_outputs(), 262 " tensor(s), but get ", rets->size(), " tensor(s) instead.")); 263 } else { 264 for (size_t i = 0; i < rets->size(); ++i) { 265 ctx->set_output(i, (*rets)[i]); 266 } 267 } 268 delete rets; 269 done(); 270 }); 271 } 272 273 private: 274 TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp); 275 }; 276 277 REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU), 278 SymbolicGradientOp); 279 REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU), 280 SymbolicGradientOp); 281 #if TENSORFLOW_USE_SYCL 282 REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL), 283 SymbolicGradientOp); 284 285 #endif // TENSORFLOW_USE_SYCL 286 287 class RemoteCallOp : public AsyncOpKernel { 288 public: 289 explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { 290 OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); 291 } 292 293 ~RemoteCallOp() override {} 294 295 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 296 const Tensor* target; 297 OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); 298 const string& target_device = 299 DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()()); 300 301 FunctionLibraryRuntime* lib = ctx->function_library(); 302 OP_REQUIRES_ASYNC(ctx, lib != nullptr, 303 errors::Internal("No function library is provided."), 304 done); 305 AttrValueMap attr_values = func_.attr(); 306 FunctionLibraryRuntime::InstantiateOptions instantiate_opts; 307 instantiate_opts.target = target_device; 308 FunctionLibraryRuntime::Handle handle; 309 OP_REQUIRES_OK_ASYNC(ctx, 310 lib->Instantiate(func_.name(), AttrSlice(&attr_values), 311 instantiate_opts, &handle), 312 done); 313 314 OpInputList arguments; 315 OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); 316 317 FunctionLibraryRuntime::Options opts; 318 opts.step_id = ctx->step_id(); 319 opts.runner = ctx->runner(); 320 opts.source_device = lib->device()->name(); 321 if (opts.source_device != target_device) { 322 opts.remote_execution = true; 323 } 324 opts.create_rendezvous = true; 325 std::vector<Tensor> args; 326 args.reserve(arguments.size()); 327 for (const Tensor& argument : arguments) { 328 args.push_back(argument); 329 } 330 auto* rets = new std::vector<Tensor>; 331 lib->Run(opts, handle, args, rets, [rets, done, ctx](const Status& status) { 332 if (!status.ok()) { 333 ctx->SetStatus(status); 334 } else { 335 for (size_t i = 0; i < rets->size(); ++i) { 336 ctx->set_output(i, (*rets)[i]); 337 } 338 } 339 delete rets; 340 done(); 341 }); 342 } 343 344 private: 345 string target_; 346 NameAttrList func_; 347 TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); 348 }; 349 350 REGISTER_KERNEL_BUILDER( 351 Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp); 352 REGISTER_KERNEL_BUILDER( 353 Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp); 354 #if TENSORFLOW_USE_SYCL 355 REGISTER_KERNEL_BUILDER( 356 Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp); 357 358 #endif // TENSORFLOW_USE_SYCL 359 } // namespace tensorflow 360