1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/control_flow_ops.h" 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/register_types.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/types.h" 22 #include "tensorflow/core/platform/macros.h" 23 24 namespace tensorflow { 25 26 void SwitchOp::Compute(OpKernelContext* context) { 27 const Tensor& outputPorts = context->input(1); 28 OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()), 29 errors::InvalidArgument("The second input must be a scalar, " 30 "but it has shape ", 31 outputPorts.shape().DebugString())); 32 33 bool pred = outputPorts.scalar<bool>()(); 34 int port = (pred) ? 
1 : 0; 35 if (context->input_is_ref(0)) { 36 context->forward_ref_input_to_ref_output(0, port); 37 } else { 38 context->set_output(port, context->input(0)); 39 } 40 } 41 42 #define REGISTER_CPU_SWITCH(type) \ 43 REGISTER_KERNEL_BUILDER(Name("Switch") \ 44 .Device(DEVICE_CPU) \ 45 .HostMemory("pred") \ 46 .TypeConstraint<type>("T"), \ 47 SwitchOp) 48 49 #define REGISTER_CPU_REF_SWITCH(type) \ 50 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 51 .Device(DEVICE_CPU) \ 52 .HostMemory("pred") \ 53 .TypeConstraint<type>("T"), \ 54 SwitchOp) 55 56 #define REGISTER_GPU_SWITCH(type) \ 57 REGISTER_KERNEL_BUILDER(Name("Switch") \ 58 .Device(DEVICE_GPU) \ 59 .HostMemory("pred") \ 60 .TypeConstraint<type>("T"), \ 61 SwitchOp) 62 63 #define REGISTER_GPU_REF_SWITCH(type) \ 64 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 65 .Device(DEVICE_GPU) \ 66 .HostMemory("pred") \ 67 .TypeConstraint<type>("T"), \ 68 SwitchOp) 69 70 TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH); 71 TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH); 72 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH); 73 74 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH); 75 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH); 76 77 #undef REGISTER_CPU_SWITCH 78 #undef REGISTER_CPU_REF_SWITCH 79 #undef REGISTER_GPU_SWITCH 80 #undef REGISTER_GPU_REF_SWITCH 81 82 // Special GPU kernels for int32 and string. 83 // TODO(b/25387198): Also enable int32 in device memory. This kernel 84 // registration requires all int32 inputs and outputs to be in host memory. 
85 #define REGISTER_GPU_HOST_KERNEL(type) \ 86 REGISTER_KERNEL_BUILDER(Name("Switch") \ 87 .Device(DEVICE_GPU) \ 88 .HostMemory("data") \ 89 .HostMemory("pred") \ 90 .HostMemory("output_false") \ 91 .HostMemory("output_true") \ 92 .TypeConstraint<type>("T"), \ 93 SwitchOp) 94 95 #define REGISTER_GPU_HOST_REF_KERNEL(type) \ 96 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 97 .Device(DEVICE_GPU) \ 98 .HostMemory("data") \ 99 .HostMemory("pred") \ 100 .HostMemory("output_false") \ 101 .HostMemory("output_true") \ 102 .TypeConstraint<type>("T"), \ 103 SwitchOp) 104 105 REGISTER_GPU_HOST_KERNEL(int32); 106 REGISTER_GPU_HOST_REF_KERNEL(int32); 107 REGISTER_GPU_HOST_KERNEL(bool); 108 REGISTER_GPU_HOST_REF_KERNEL(bool); 109 REGISTER_GPU_HOST_KERNEL(string); 110 REGISTER_GPU_HOST_REF_KERNEL(string); 111 112 #undef REGISTER_GPU_HOST_KERNEL 113 #undef REGISTER_GPU_HOST_REF_KERNEL 114 115 #ifdef TENSORFLOW_USE_SYCL 116 #define REGISTER_SYCL_SWITCH(type) \ 117 REGISTER_KERNEL_BUILDER(Name("Switch") \ 118 .Device(DEVICE_SYCL) \ 119 .HostMemory("pred") \ 120 .TypeConstraint<type>("T"), \ 121 SwitchOp) 122 TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH); 123 124 #define REGISTER_SYCL_REF_SWITCH(type) \ 125 REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 126 .Device(DEVICE_SYCL) \ 127 .HostMemory("pred") \ 128 .TypeConstraint<type>("T"), \ 129 SwitchOp) 130 TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH); 131 132 #undef REGISTER_SYCL_SWITCH 133 #undef REGISTER_SYCL_REF_SWITCH 134 135 #define REGISTER_SYCL_HOST_KERNEL(type) \ 136 REGISTER_KERNEL_BUILDER(Name("Switch") \ 137 .Device(DEVICE_SYCL) \ 138 .HostMemory("data") \ 139 .HostMemory("pred") \ 140 .HostMemory("output_false") \ 141 .HostMemory("output_true") \ 142 .TypeConstraint<type>("T"), \ 143 SwitchOp) 144 145 REGISTER_SYCL_HOST_KERNEL(bool); 146 REGISTER_SYCL_HOST_KERNEL(string); 147 REGISTER_SYCL_HOST_KERNEL(int32); 148 149 #define REGISTER_SYCL_HOST_REF_KERNEL(type) \ 150 
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ 151 .Device(DEVICE_SYCL) \ 152 .HostMemory("data") \ 153 .HostMemory("pred") \ 154 .HostMemory("output_false") \ 155 .HostMemory("output_true") \ 156 .TypeConstraint<type>("T"), \ 157 SwitchOp) 158 159 REGISTER_SYCL_HOST_REF_KERNEL(int32); 160 REGISTER_SYCL_HOST_REF_KERNEL(bool); 161 REGISTER_SYCL_HOST_REF_KERNEL(string); 162 163 #undef REGISTER_SYCL_HOST_KERNEL 164 #undef REGISTER_SYCL_HOST_REF_KERNEL 165 #endif // TENSORFLOW_USE_SYCL 166 167 class RefSelectOp : public OpKernel { 168 public: 169 explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) { 170 OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_)); 171 } 172 173 void Compute(OpKernelContext* context) override { 174 const Tensor& index_tensor = context->input(0); 175 OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()), 176 errors::InvalidArgument("Index must be a scalar, " 177 "but it has shape ", 178 index_tensor.shape().DebugString())); 179 180 int32 index = index_tensor.scalar<int32>()(); 181 182 OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_, 183 errors::InvalidArgument("Index must be in the range [0, ", 184 num_ref_inputs_, ") but got ", index)); 185 context->forward_ref_input_to_ref_output(index + 1, 0); 186 } 187 188 bool IsExpensive() override { return false; } 189 190 ~RefSelectOp() override {} 191 192 TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp); 193 194 private: 195 int num_ref_inputs_; 196 }; 197 198 #define REGISTER_CPU_REF_SELECT(type) \ 199 REGISTER_KERNEL_BUILDER(Name("RefSelect") \ 200 .Device(DEVICE_CPU) \ 201 .HostMemory("index") \ 202 .TypeConstraint<type>("T"), \ 203 RefSelectOp) 204 TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT); 205 206 #undef REGISTER_CPU_REF_SWITCH 207 208 MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) { 209 const DataType dt = context->input_type(0); 210 const int num_in = context->num_inputs(); 211 OP_REQUIRES_OK(context, 
context->MatchSignature(DataTypeVector(num_in, dt), 212 {dt, DT_INT32})); 213 } 214 215 void MergeOp::Compute(OpKernelContext* context) { 216 bool input_seen = false; 217 for (int i = 0; i < context->num_inputs(); ++i) { 218 if (context->has_input(i)) { 219 if (input_seen) { 220 context->SetStatus( 221 errors::Internal("Merge can not have more than one valid input.")); 222 return; 223 } 224 input_seen = true; 225 226 if (IsRefType(context->input_dtype(i))) { 227 context->forward_ref_input_to_ref_output(i, 0); 228 } else { 229 context->set_output(0, context->input(i)); 230 } 231 Tensor* value_index = nullptr; 232 OP_REQUIRES_OK( 233 context, context->allocate_output(1, TensorShape({}), &value_index)); 234 value_index->scalar<int32>()() = i; 235 } 236 } 237 } 238 239 REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp); 240 REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp); 241 242 #define REGISTER_GPU_KERNEL(type) \ 243 REGISTER_KERNEL_BUILDER(Name("Merge") \ 244 .Device(DEVICE_GPU) \ 245 .TypeConstraint<type>("T") \ 246 .HostMemory("value_index"), \ 247 MergeOp); 248 249 #define REGISTER_GPU_REF_KERNEL(type) \ 250 REGISTER_KERNEL_BUILDER(Name("RefMerge") \ 251 .Device(DEVICE_GPU) \ 252 .TypeConstraint<type>("T") \ 253 .HostMemory("value_index"), \ 254 MergeOp); 255 256 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); 257 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); 258 REGISTER_GPU_KERNEL(bool); 259 REGISTER_GPU_REF_KERNEL(bool); 260 261 #undef REGISTER_GPU_KERNEL 262 #undef REGISTER_GPU_REF_KERNEL 263 264 #ifdef TENSORFLOW_USE_SYCL 265 #define REGISTER_SYCL_KERNEL(type) \ 266 REGISTER_KERNEL_BUILDER(Name("Merge") \ 267 .Device(DEVICE_SYCL) \ 268 .TypeConstraint<type>("T") \ 269 .HostMemory("value_index"), \ 270 MergeOp); 271 REGISTER_SYCL_KERNEL(bool); 272 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); 273 274 #define REGISTER_SYCL_REF_KERNEL(type) \ 275 REGISTER_KERNEL_BUILDER(Name("RefMerge") \ 276 
.Device(DEVICE_SYCL) \ 277 .TypeConstraint<type>("T") \ 278 .HostMemory("value_index"), \ 279 MergeOp); 280 REGISTER_SYCL_REF_KERNEL(bool); 281 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); 282 283 #undef REGISTER_SYCL_KERNEL 284 #undef REGISTER_SYCL_REF_KERNEL 285 #endif // TENSORFLOW_USE_SYCL 286 287 // Special GPU kernels for int32 and string. 288 // TODO(b/25387198): Also enable int32 in device memory. This kernel 289 // registration requires all int32 inputs and outputs to be in host memory. 290 #define REGISTER_GPU_HOST_KERNEL(type) \ 291 REGISTER_KERNEL_BUILDER(Name("Merge") \ 292 .Device(DEVICE_GPU) \ 293 .HostMemory("inputs") \ 294 .HostMemory("output") \ 295 .HostMemory("value_index") \ 296 .TypeConstraint<type>("T"), \ 297 MergeOp); \ 298 REGISTER_KERNEL_BUILDER(Name("RefMerge") \ 299 .Device(DEVICE_GPU) \ 300 .HostMemory("inputs") \ 301 .HostMemory("output") \ 302 .HostMemory("value_index") \ 303 .TypeConstraint<type>("T"), \ 304 MergeOp) 305 306 REGISTER_GPU_HOST_KERNEL(int32); 307 REGISTER_GPU_HOST_KERNEL(string); 308 REGISTER_GPU_HOST_KERNEL(ResourceHandle); 309 310 #undef REGISTER_GPU_HOST_KERNEL 311 312 #ifdef TENSORFLOW_USE_SYCL 313 #define REGISTER_SYCL_HOST_KERNEL(type) \ 314 REGISTER_KERNEL_BUILDER(Name("Merge") \ 315 .Device(DEVICE_SYCL) \ 316 .HostMemory("inputs") \ 317 .HostMemory("output") \ 318 .HostMemory("value_index") \ 319 .TypeConstraint<type>("T"), \ 320 MergeOp); \ 321 REGISTER_KERNEL_BUILDER(Name("RefMerge") \ 322 .Device(DEVICE_SYCL) \ 323 .HostMemory("inputs") \ 324 .HostMemory("output") \ 325 .HostMemory("value_index") \ 326 .TypeConstraint<type>("T"), \ 327 MergeOp) 328 329 REGISTER_SYCL_HOST_KERNEL(int32); 330 REGISTER_SYCL_HOST_KERNEL(string); 331 REGISTER_SYCL_HOST_KERNEL(ResourceHandle); 332 333 #undef REGISTER_SYCL_HOST_KERNEL 334 #endif // TENSORFLOW_USE_SYCL 335 336 void EnterOp::Compute(OpKernelContext* context) { 337 if (IsRefType(context->input_dtype(0))) { 338 context->forward_ref_input_to_ref_output(0, 
0); 339 } else { 340 context->set_output(0, context->input(0)); 341 } 342 } 343 344 REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp); 345 REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp); 346 347 #define REGISTER_GPU_KERNEL(type) \ 348 REGISTER_KERNEL_BUILDER( \ 349 Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp) 350 #define REGISTER_GPU_REF_KERNEL(type) \ 351 REGISTER_KERNEL_BUILDER( \ 352 Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp) 353 354 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); 355 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); 356 REGISTER_GPU_KERNEL(bool); 357 REGISTER_GPU_REF_KERNEL(bool); 358 359 #undef REGISTER_GPU_KERNEL 360 #undef REGISTER_GPU_REF_KERNEL 361 362 #ifdef TENSORFLOW_USE_SYCL 363 #define REGISTER_SYCL_KERNEL(type) \ 364 REGISTER_KERNEL_BUILDER( \ 365 Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) 366 REGISTER_SYCL_KERNEL(bool); 367 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); 368 369 #define REGISTER_SYCL_REF_KERNEL(type) \ 370 REGISTER_KERNEL_BUILDER( \ 371 Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp) 372 REGISTER_SYCL_REF_KERNEL(bool); 373 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); 374 375 #undef REGISTER_SYCL_KERNEL 376 #undef REGISTER_SYCL_REF_KERNEL 377 #define REGISTER_SYCL_HOST_KERNEL(type) \ 378 REGISTER_KERNEL_BUILDER(Name("Enter") \ 379 .Device(DEVICE_SYCL) \ 380 .HostMemory("data") \ 381 .HostMemory("output") \ 382 .TypeConstraint<type>("T"), \ 383 EnterOp) 384 385 #define REGISTER_SYCL_HOST_REF_KERNEL(type) \ 386 REGISTER_KERNEL_BUILDER(Name("RefEnter") \ 387 .Device(DEVICE_SYCL) \ 388 .HostMemory("data") \ 389 .HostMemory("output") \ 390 .TypeConstraint<type>("T"), \ 391 EnterOp) 392 393 REGISTER_SYCL_HOST_KERNEL(int32); 394 REGISTER_SYCL_HOST_REF_KERNEL(int32); 395 REGISTER_SYCL_HOST_KERNEL(string); 396 REGISTER_SYCL_HOST_REF_KERNEL(string); 397 
REGISTER_SYCL_HOST_KERNEL(ResourceHandle); 398 399 #undef REGISTER_SYCL_HOST_KERNEL 400 #undef REGISTER_SYCL_HOST_REF_KERNEL 401 #endif // TENSORFLOW_USE_SYCL 402 403 // Special GPU kernels for int32 and string. 404 // TODO(b/25387198): Also enable int32 in device memory. This kernel 405 // registration requires all int32 inputs and outputs to be in host memory. 406 #define REGISTER_GPU_HOST_KERNEL(type) \ 407 REGISTER_KERNEL_BUILDER(Name("Enter") \ 408 .Device(DEVICE_GPU) \ 409 .HostMemory("data") \ 410 .HostMemory("output") \ 411 .TypeConstraint<type>("T"), \ 412 EnterOp) 413 414 #define REGISTER_GPU_HOST_REF_KERNEL(type) \ 415 REGISTER_KERNEL_BUILDER(Name("RefEnter") \ 416 .Device(DEVICE_GPU) \ 417 .HostMemory("data") \ 418 .HostMemory("output") \ 419 .TypeConstraint<type>("T"), \ 420 EnterOp) 421 422 REGISTER_GPU_HOST_KERNEL(int32); 423 REGISTER_GPU_HOST_REF_KERNEL(int32); 424 REGISTER_GPU_HOST_KERNEL(string); 425 REGISTER_GPU_HOST_REF_KERNEL(string); 426 REGISTER_GPU_HOST_KERNEL(ResourceHandle); 427 428 #undef REGISTER_GPU_HOST_KERNEL 429 #undef REGISTER_GPU_HOST_REF_KERNEL 430 431 void ExitOp::Compute(OpKernelContext* context) { 432 if (IsRefType(context->input_dtype(0))) { 433 context->forward_ref_input_to_ref_output(0, 0); 434 } else { 435 context->set_output(0, context->input(0)); 436 } 437 } 438 439 REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp); 440 REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp); 441 442 #define REGISTER_GPU_KERNEL(type) \ 443 REGISTER_KERNEL_BUILDER( \ 444 Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); 445 #define REGISTER_GPU_REF_KERNEL(type) \ 446 REGISTER_KERNEL_BUILDER( \ 447 Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp); 448 449 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); 450 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); 451 REGISTER_GPU_KERNEL(bool); 452 REGISTER_GPU_REF_KERNEL(bool); 453 454 #undef REGISTER_GPU_KERNEL 455 #undef 
REGISTER_GPU_REF_KERNEL 456 457 #ifdef TENSORFLOW_USE_SYCL 458 #define REGISTER_SYCL_KERNEL(type) \ 459 REGISTER_KERNEL_BUILDER( \ 460 Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \ 461 REGISTER_KERNEL_BUILDER( \ 462 Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); 463 REGISTER_SYCL_KERNEL(bool); 464 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); 465 466 #undef REGISTER_SYCL_KERNEL 467 #undef REGISTER_SYCL_REF_KERNEL 468 469 #define REGISTER_SYCL_HOST_KERNEL(type) \ 470 REGISTER_KERNEL_BUILDER(Name("Exit") \ 471 .Device(DEVICE_SYCL) \ 472 .HostMemory("data") \ 473 .HostMemory("output") \ 474 .TypeConstraint<type>("T"), \ 475 ExitOp); \ 476 REGISTER_KERNEL_BUILDER(Name("RefExit") \ 477 .Device(DEVICE_SYCL) \ 478 .HostMemory("data") \ 479 .HostMemory("output") \ 480 .TypeConstraint<type>("T"), \ 481 ExitOp) 482 483 REGISTER_SYCL_HOST_KERNEL(int32); 484 REGISTER_SYCL_HOST_KERNEL(string); 485 #undef REGISTER_SYCL_HOST_KERNEL 486 #endif // TENSORFLOW_USE_SYCL 487 488 // Special GPU kernels for int32 and string. 489 // TODO(b/25387198): Also enable int32 in device memory. This kernel 490 // registration requires all int32 inputs and outputs to be in host memory. 
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);

#undef REGISTER_GPU_HOST_KERNEL

// Passes input 0 straight through to output 0, forwarding refs as refs.
// Marks the start of the next iteration of a loop frame.
void NextIterationOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
                        NextIterationOp);

// One macro registers both NextIteration and RefNextIteration for GPU.
#define REGISTER_GPU_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                      \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      NextIterationOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);

#undef REGISTER_GPU_KERNEL

// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(string);

#undef REGISTER_GPU_HOST_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      NextIterationOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);

#undef REGISTER_SYCL_KERNEL

#define REGISTER_SYCL_HOST_KERNEL(type)                   \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_SYCL)        \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(string);
#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// A LoopCond op has one input and one output. The input is a boolean
// scalar representing the taken branches of the "pivot" Switch that
// determines loop termination. As a contract, any high-level front-end
// should always use port '0' of the "pivot" switches for loop exit.
594 class LoopCondOp : public OpKernel { 595 public: 596 explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} 597 598 void Compute(OpKernelContext* context) override { 599 context->set_output(0, context->input(0)); 600 } 601 602 bool IsExpensive() override { return false; } 603 604 ~LoopCondOp() override {} 605 606 TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); 607 }; 608 609 REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp); 610 REGISTER_KERNEL_BUILDER(Name("LoopCond") 611 .Device(DEVICE_GPU) 612 .HostMemory("input") 613 .HostMemory("output"), 614 LoopCondOp); 615 616 #ifdef TENSORFLOW_USE_SYCL 617 REGISTER_KERNEL_BUILDER(Name("LoopCond") 618 .Device(DEVICE_SYCL) 619 .HostMemory("input") 620 .HostMemory("output"), 621 LoopCondOp); 622 #endif // TENSORFLOW_USE_SYCL 623 624 // ControlTrigger kernels 625 REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU), 626 ControlTriggerOp); 627 628 REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU), 629 ControlTriggerOp); 630 631 #ifdef TENSORFLOW_USE_SYCL 632 REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL), 633 ControlTriggerOp); 634 #endif // TENSORFLOW_USE_SYCL 635 636 // When called, abort op will abort the current process. This can be used to 637 // abort remote PSs when needed. 
638 class AbortOp : public OpKernel { 639 public: 640 explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) { 641 OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_)); 642 OP_REQUIRES_OK( 643 context, context->GetAttr("exit_without_error", &exit_without_error_)); 644 } 645 646 void Compute(OpKernelContext* context) override { 647 if (!exit_without_error_) { 648 LOG(FATAL) << "Abort_op intentional failure; " << error_msg_; 649 } else { 650 LOG(WARNING) << "Exiting the process: " << error_msg_; 651 exit(0); 652 } 653 } 654 655 private: 656 string error_msg_; 657 bool exit_without_error_; 658 }; 659 660 REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp); 661 662 } // namespace tensorflow 663