Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/control_flow_ops.h"
     17 
     18 #include "tensorflow/core/framework/op_kernel.h"
     19 #include "tensorflow/core/framework/register_types.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/framework/types.h"
     22 #include "tensorflow/core/platform/macros.h"
     23 
     24 namespace tensorflow {
     25 
     26 void SwitchOp::Compute(OpKernelContext* context) {
     27   const Tensor& outputPorts = context->input(1);
     28   OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()),
     29               errors::InvalidArgument("The second input must be a scalar, "
     30                                       "but it has shape ",
     31                                       outputPorts.shape().DebugString()));
     32 
     33   bool pred = outputPorts.scalar<bool>()();
     34   int port = (pred) ? 1 : 0;
     35   if (context->input_is_ref(0)) {
     36     context->forward_ref_input_to_ref_output(0, port);
     37   } else {
     38     context->set_output(port, context->input(0));
     39   }
     40 }
     41 
     42 #define REGISTER_CPU_SWITCH(type)                         \
     43   REGISTER_KERNEL_BUILDER(Name("Switch")                  \
     44                               .Device(DEVICE_CPU)         \
     45                               .HostMemory("pred")         \
     46                               .TypeConstraint<type>("T"), \
     47                           SwitchOp)
     48 
     49 #define REGISTER_CPU_REF_SWITCH(type)                     \
     50   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
     51                               .Device(DEVICE_CPU)         \
     52                               .HostMemory("pred")         \
     53                               .TypeConstraint<type>("T"), \
     54                           SwitchOp)
     55 
     56 #define REGISTER_GPU_SWITCH(type)                         \
     57   REGISTER_KERNEL_BUILDER(Name("Switch")                  \
     58                               .Device(DEVICE_GPU)         \
     59                               .HostMemory("pred")         \
     60                               .TypeConstraint<type>("T"), \
     61                           SwitchOp)
     62 
     63 #define REGISTER_GPU_REF_SWITCH(type)                     \
     64   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
     65                               .Device(DEVICE_GPU)         \
     66                               .HostMemory("pred")         \
     67                               .TypeConstraint<type>("T"), \
     68                           SwitchOp)
     69 
     70 TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
     71 TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
     72 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
     73 
     74 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
     75 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
     76 
     77 #undef REGISTER_CPU_SWITCH
     78 #undef REGISTER_CPU_REF_SWITCH
     79 #undef REGISTER_GPU_SWITCH
     80 #undef REGISTER_GPU_REF_SWITCH
     81 
     82 // Special GPU kernels for int32 and string.
     83 // TODO(b/25387198): Also enable int32 in device memory. This kernel
     84 // registration requires all int32 inputs and outputs to be in host memory.
     85 #define REGISTER_GPU_HOST_KERNEL(type)                    \
     86   REGISTER_KERNEL_BUILDER(Name("Switch")                  \
     87                               .Device(DEVICE_GPU)         \
     88                               .HostMemory("data")         \
     89                               .HostMemory("pred")         \
     90                               .HostMemory("output_false") \
     91                               .HostMemory("output_true")  \
     92                               .TypeConstraint<type>("T"), \
     93                           SwitchOp)
     94 
     95 #define REGISTER_GPU_HOST_REF_KERNEL(type)                \
     96   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
     97                               .Device(DEVICE_GPU)         \
     98                               .HostMemory("data")         \
     99                               .HostMemory("pred")         \
    100                               .HostMemory("output_false") \
    101                               .HostMemory("output_true")  \
    102                               .TypeConstraint<type>("T"), \
    103                           SwitchOp)
    104 
    105 REGISTER_GPU_HOST_KERNEL(int32);
    106 REGISTER_GPU_HOST_REF_KERNEL(int32);
    107 REGISTER_GPU_HOST_KERNEL(bool);
    108 REGISTER_GPU_HOST_REF_KERNEL(bool);
    109 REGISTER_GPU_HOST_KERNEL(string);
    110 REGISTER_GPU_HOST_REF_KERNEL(string);
    111 
    112 #undef REGISTER_GPU_HOST_KERNEL
    113 #undef REGISTER_GPU_HOST_REF_KERNEL
    114 
    115 #ifdef TENSORFLOW_USE_SYCL
    116 #define REGISTER_SYCL_SWITCH(type)                        \
    117   REGISTER_KERNEL_BUILDER(Name("Switch")                  \
    118                               .Device(DEVICE_SYCL)        \
    119                               .HostMemory("pred")         \
    120                               .TypeConstraint<type>("T"), \
    121                           SwitchOp)
    122 TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH);
    123 
    124 #define REGISTER_SYCL_REF_SWITCH(type)                    \
    125   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
    126                               .Device(DEVICE_SYCL)        \
    127                               .HostMemory("pred")         \
    128                               .TypeConstraint<type>("T"), \
    129                           SwitchOp)
    130 TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
    131 
    132 #undef REGISTER_SYCL_SWITCH
    133 #undef REGISTER_SYCL_REF_SWITCH
    134 
    135 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
    136   REGISTER_KERNEL_BUILDER(Name("Switch")                  \
    137                               .Device(DEVICE_SYCL)        \
    138                               .HostMemory("data")         \
    139                               .HostMemory("pred")         \
    140                               .HostMemory("output_false") \
    141                               .HostMemory("output_true")  \
    142                               .TypeConstraint<type>("T"), \
    143                           SwitchOp)
    144 
    145 REGISTER_SYCL_HOST_KERNEL(bool);
    146 REGISTER_SYCL_HOST_KERNEL(string);
    147 REGISTER_SYCL_HOST_KERNEL(int32);
    148 
    149 #define REGISTER_SYCL_HOST_REF_KERNEL(type)               \
    150   REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
    151                               .Device(DEVICE_SYCL)        \
    152                               .HostMemory("data")         \
    153                               .HostMemory("pred")         \
    154                               .HostMemory("output_false") \
    155                               .HostMemory("output_true")  \
    156                               .TypeConstraint<type>("T"), \
    157                           SwitchOp)
    158 
    159 REGISTER_SYCL_HOST_REF_KERNEL(int32);
    160 REGISTER_SYCL_HOST_REF_KERNEL(bool);
    161 REGISTER_SYCL_HOST_REF_KERNEL(string);
    162 
    163 #undef REGISTER_SYCL_HOST_KERNEL
    164 #undef REGISTER_SYCL_HOST_REF_KERNEL
    165 #endif  // TENSORFLOW_USE_SYCL
    166 
    167 class RefSelectOp : public OpKernel {
    168  public:
    169   explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
    170     OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_));
    171   }
    172 
    173   void Compute(OpKernelContext* context) override {
    174     const Tensor& index_tensor = context->input(0);
    175     OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()),
    176                 errors::InvalidArgument("Index must be a scalar, "
    177                                         "but it has shape ",
    178                                         index_tensor.shape().DebugString()));
    179 
    180     int32 index = index_tensor.scalar<int32>()();
    181 
    182     OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_,
    183                 errors::InvalidArgument("Index must be in the range [0, ",
    184                                         num_ref_inputs_, ") but got ", index));
    185     context->forward_ref_input_to_ref_output(index + 1, 0);
    186   }
    187 
    188   bool IsExpensive() override { return false; }
    189 
    190   ~RefSelectOp() override {}
    191 
    192   TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp);
    193 
    194  private:
    195   int num_ref_inputs_;
    196 };
    197 
    198 #define REGISTER_CPU_REF_SELECT(type)                     \
    199   REGISTER_KERNEL_BUILDER(Name("RefSelect")               \
    200                               .Device(DEVICE_CPU)         \
    201                               .HostMemory("index")        \
    202                               .TypeConstraint<type>("T"), \
    203                           RefSelectOp)
    204 TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);
    205 
    206 #undef REGISTER_CPU_REF_SWITCH
    207 
    208 MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) {
    209   const DataType dt = context->input_type(0);
    210   const int num_in = context->num_inputs();
    211   OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
    212                                                   {dt, DT_INT32}));
    213 }
    214 
    215 void MergeOp::Compute(OpKernelContext* context) {
    216   bool input_seen = false;
    217   for (int i = 0; i < context->num_inputs(); ++i) {
    218     if (context->has_input(i)) {
    219       if (input_seen) {
    220         context->SetStatus(
    221             errors::Internal("Merge can not have more than one valid input."));
    222         return;
    223       }
    224       input_seen = true;
    225 
    226       if (IsRefType(context->input_dtype(i))) {
    227         context->forward_ref_input_to_ref_output(i, 0);
    228       } else {
    229         context->set_output(0, context->input(i));
    230       }
    231       Tensor* value_index = nullptr;
    232       OP_REQUIRES_OK(
    233           context, context->allocate_output(1, TensorShape({}), &value_index));
    234       value_index->scalar<int32>()() = i;
    235     }
    236   }
    237 }
    238 
    239 REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
    240 REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);
    241 
    242 #define REGISTER_GPU_KERNEL(type)                         \
    243   REGISTER_KERNEL_BUILDER(Name("Merge")                   \
    244                               .Device(DEVICE_GPU)         \
    245                               .TypeConstraint<type>("T")  \
    246                               .HostMemory("value_index"), \
    247                           MergeOp);
    248 
    249 #define REGISTER_GPU_REF_KERNEL(type)                     \
    250   REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
    251                               .Device(DEVICE_GPU)         \
    252                               .TypeConstraint<type>("T")  \
    253                               .HostMemory("value_index"), \
    254                           MergeOp);
    255 
    256 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
    257 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
    258 REGISTER_GPU_KERNEL(bool);
    259 REGISTER_GPU_REF_KERNEL(bool);
    260 
    261 #undef REGISTER_GPU_KERNEL
    262 #undef REGISTER_GPU_REF_KERNEL
    263 
    264 #ifdef TENSORFLOW_USE_SYCL
    265 #define REGISTER_SYCL_KERNEL(type)                        \
    266   REGISTER_KERNEL_BUILDER(Name("Merge")                   \
    267                               .Device(DEVICE_SYCL)        \
    268                               .TypeConstraint<type>("T")  \
    269                               .HostMemory("value_index"), \
    270                           MergeOp);
    271 REGISTER_SYCL_KERNEL(bool);
    272 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
    273 
    274 #define REGISTER_SYCL_REF_KERNEL(type)                    \
    275   REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
    276                               .Device(DEVICE_SYCL)        \
    277                               .TypeConstraint<type>("T")  \
    278                               .HostMemory("value_index"), \
    279                           MergeOp);
    280 REGISTER_SYCL_REF_KERNEL(bool);
    281 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
    282 
    283 #undef REGISTER_SYCL_KERNEL
    284 #undef REGISTER_SYCL_REF_KERNEL
    285 #endif  // TENSORFLOW_USE_SYCL
    286 
    287 // Special GPU kernels for int32 and string.
    288 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    289 // registration requires all int32 inputs and outputs to be in host memory.
    290 #define REGISTER_GPU_HOST_KERNEL(type)                    \
    291   REGISTER_KERNEL_BUILDER(Name("Merge")                   \
    292                               .Device(DEVICE_GPU)         \
    293                               .HostMemory("inputs")       \
    294                               .HostMemory("output")       \
    295                               .HostMemory("value_index")  \
    296                               .TypeConstraint<type>("T"), \
    297                           MergeOp);                       \
    298   REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
    299                               .Device(DEVICE_GPU)         \
    300                               .HostMemory("inputs")       \
    301                               .HostMemory("output")       \
    302                               .HostMemory("value_index")  \
    303                               .TypeConstraint<type>("T"), \
    304                           MergeOp)
    305 
    306 REGISTER_GPU_HOST_KERNEL(int32);
    307 REGISTER_GPU_HOST_KERNEL(string);
    308 REGISTER_GPU_HOST_KERNEL(ResourceHandle);
    309 
    310 #undef REGISTER_GPU_HOST_KERNEL
    311 
    312 #ifdef TENSORFLOW_USE_SYCL
    313 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
    314   REGISTER_KERNEL_BUILDER(Name("Merge")                   \
    315                               .Device(DEVICE_SYCL)        \
    316                               .HostMemory("inputs")       \
    317                               .HostMemory("output")       \
    318                               .HostMemory("value_index")  \
    319                               .TypeConstraint<type>("T"), \
    320                           MergeOp);                       \
    321   REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
    322                               .Device(DEVICE_SYCL)        \
    323                               .HostMemory("inputs")       \
    324                               .HostMemory("output")       \
    325                               .HostMemory("value_index")  \
    326                               .TypeConstraint<type>("T"), \
    327                           MergeOp)
    328 
    329 REGISTER_SYCL_HOST_KERNEL(int32);
    330 REGISTER_SYCL_HOST_KERNEL(string);
    331 REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
    332 
    333 #undef REGISTER_SYCL_HOST_KERNEL
    334 #endif  // TENSORFLOW_USE_SYCL
    335 
    336 void EnterOp::Compute(OpKernelContext* context) {
    337   if (IsRefType(context->input_dtype(0))) {
    338     context->forward_ref_input_to_ref_output(0, 0);
    339   } else {
    340     context->set_output(0, context->input(0));
    341   }
    342 }
    343 
    344 REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp);
    345 REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
    346 
    347 #define REGISTER_GPU_KERNEL(type) \
    348   REGISTER_KERNEL_BUILDER(        \
    349       Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
    350 #define REGISTER_GPU_REF_KERNEL(type) \
    351   REGISTER_KERNEL_BUILDER(            \
    352       Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
    353 
    354 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
    355 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
    356 REGISTER_GPU_KERNEL(bool);
    357 REGISTER_GPU_REF_KERNEL(bool);
    358 
    359 #undef REGISTER_GPU_KERNEL
    360 #undef REGISTER_GPU_REF_KERNEL
    361 
    362 #ifdef TENSORFLOW_USE_SYCL
    363 #define REGISTER_SYCL_KERNEL(type) \
    364   REGISTER_KERNEL_BUILDER(         \
    365       Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
    366 REGISTER_SYCL_KERNEL(bool);
    367 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
    368 
    369 #define REGISTER_SYCL_REF_KERNEL(type) \
    370   REGISTER_KERNEL_BUILDER(             \
    371       Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
    372 REGISTER_SYCL_REF_KERNEL(bool);
    373 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
    374 
    375 #undef REGISTER_SYCL_KERNEL
    376 #undef REGISTER_SYCL_REF_KERNEL
    377 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
    378   REGISTER_KERNEL_BUILDER(Name("Enter")                   \
    379                               .Device(DEVICE_SYCL)        \
    380                               .HostMemory("data")         \
    381                               .HostMemory("output")       \
    382                               .TypeConstraint<type>("T"), \
    383                           EnterOp)
    384 
    385 #define REGISTER_SYCL_HOST_REF_KERNEL(type)               \
    386   REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
    387                               .Device(DEVICE_SYCL)        \
    388                               .HostMemory("data")         \
    389                               .HostMemory("output")       \
    390                               .TypeConstraint<type>("T"), \
    391                           EnterOp)
    392 
    393 REGISTER_SYCL_HOST_KERNEL(int32);
    394 REGISTER_SYCL_HOST_REF_KERNEL(int32);
    395 REGISTER_SYCL_HOST_KERNEL(string);
    396 REGISTER_SYCL_HOST_REF_KERNEL(string);
    397 REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
    398 
    399 #undef REGISTER_SYCL_HOST_KERNEL
    400 #undef REGISTER_SYCL_HOST_REF_KERNEL
    401 #endif  // TENSORFLOW_USE_SYCL
    402 
    403 // Special GPU kernels for int32 and string.
    404 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    405 // registration requires all int32 inputs and outputs to be in host memory.
    406 #define REGISTER_GPU_HOST_KERNEL(type)                    \
    407   REGISTER_KERNEL_BUILDER(Name("Enter")                   \
    408                               .Device(DEVICE_GPU)         \
    409                               .HostMemory("data")         \
    410                               .HostMemory("output")       \
    411                               .TypeConstraint<type>("T"), \
    412                           EnterOp)
    413 
    414 #define REGISTER_GPU_HOST_REF_KERNEL(type)                \
    415   REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
    416                               .Device(DEVICE_GPU)         \
    417                               .HostMemory("data")         \
    418                               .HostMemory("output")       \
    419                               .TypeConstraint<type>("T"), \
    420                           EnterOp)
    421 
    422 REGISTER_GPU_HOST_KERNEL(int32);
    423 REGISTER_GPU_HOST_REF_KERNEL(int32);
    424 REGISTER_GPU_HOST_KERNEL(string);
    425 REGISTER_GPU_HOST_REF_KERNEL(string);
    426 REGISTER_GPU_HOST_KERNEL(ResourceHandle);
    427 
    428 #undef REGISTER_GPU_HOST_KERNEL
    429 #undef REGISTER_GPU_HOST_REF_KERNEL
    430 
    431 void ExitOp::Compute(OpKernelContext* context) {
    432   if (IsRefType(context->input_dtype(0))) {
    433     context->forward_ref_input_to_ref_output(0, 0);
    434   } else {
    435     context->set_output(0, context->input(0));
    436   }
    437 }
    438 
    439 REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
    440 REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
    441 
    442 #define REGISTER_GPU_KERNEL(type) \
    443   REGISTER_KERNEL_BUILDER(        \
    444       Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
    445 #define REGISTER_GPU_REF_KERNEL(type) \
    446   REGISTER_KERNEL_BUILDER(            \
    447       Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
    448 
    449 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
    450 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
    451 REGISTER_GPU_KERNEL(bool);
    452 REGISTER_GPU_REF_KERNEL(bool);
    453 
    454 #undef REGISTER_GPU_KERNEL
    455 #undef REGISTER_GPU_REF_KERNEL
    456 
    457 #ifdef TENSORFLOW_USE_SYCL
    458 #define REGISTER_SYCL_KERNEL(type)                                         \
    459   REGISTER_KERNEL_BUILDER(                                                 \
    460       Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \
    461   REGISTER_KERNEL_BUILDER(                                                 \
    462       Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp);
    463 REGISTER_SYCL_KERNEL(bool);
    464 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
    465 
    466 #undef REGISTER_SYCL_KERNEL
    467 #undef REGISTER_SYCL_REF_KERNEL
    468 
    469 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
    470   REGISTER_KERNEL_BUILDER(Name("Exit")                    \
    471                               .Device(DEVICE_SYCL)        \
    472                               .HostMemory("data")         \
    473                               .HostMemory("output")       \
    474                               .TypeConstraint<type>("T"), \
    475                           ExitOp);                        \
    476   REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
    477                               .Device(DEVICE_SYCL)        \
    478                               .HostMemory("data")         \
    479                               .HostMemory("output")       \
    480                               .TypeConstraint<type>("T"), \
    481                           ExitOp)
    482 
    483 REGISTER_SYCL_HOST_KERNEL(int32);
    484 REGISTER_SYCL_HOST_KERNEL(string);
    485 #undef REGISTER_SYCL_HOST_KERNEL
    486 #endif  // TENSORFLOW_USE_SYCL
    487 
    488 // Special GPU kernels for int32 and string.
    489 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    490 // registration requires all int32 inputs and outputs to be in host memory.
    491 #define REGISTER_GPU_HOST_KERNEL(type)                    \
    492   REGISTER_KERNEL_BUILDER(Name("Exit")                    \
    493                               .Device(DEVICE_GPU)         \
    494                               .HostMemory("data")         \
    495                               .HostMemory("output")       \
    496                               .TypeConstraint<type>("T"), \
    497                           ExitOp);                        \
    498   REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
    499                               .Device(DEVICE_GPU)         \
    500                               .HostMemory("data")         \
    501                               .HostMemory("output")       \
    502                               .TypeConstraint<type>("T"), \
    503                           ExitOp)
    504 
    505 REGISTER_GPU_HOST_KERNEL(int32);
    506 REGISTER_GPU_HOST_KERNEL(string);
    507 
    508 #undef REGISTER_GPU_HOST_KERNEL
    509 
    510 void NextIterationOp::Compute(OpKernelContext* context) {
    511   if (IsRefType(context->input_dtype(0))) {
    512     context->forward_ref_input_to_ref_output(0, 0);
    513   } else {
    514     context->set_output(0, context->input(0));
    515   }
    516 }
    517 
    518 REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
    519                         NextIterationOp);
    520 REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
    521                         NextIterationOp);
    522 
    523 #define REGISTER_GPU_KERNEL(type)                                            \
    524   REGISTER_KERNEL_BUILDER(                                                   \
    525       Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
    526       NextIterationOp);                                                      \
    527   REGISTER_KERNEL_BUILDER(                                                   \
    528       Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    529       NextIterationOp)
    530 
    531 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
    532 REGISTER_GPU_KERNEL(bool);
    533 
    534 #undef REGISTER_GPU_KERNEL
    535 
    536 // Special GPU kernels for int32 and string.
    537 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    538 // registration requires all int32 inputs and outputs to be in host memory.
    539 #define REGISTER_GPU_HOST_KERNEL(type)                    \
    540   REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
    541                               .Device(DEVICE_GPU)         \
    542                               .HostMemory("data")         \
    543                               .HostMemory("output")       \
    544                               .TypeConstraint<type>("T"), \
    545                           NextIterationOp);               \
    546   REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
    547                               .Device(DEVICE_GPU)         \
    548                               .HostMemory("data")         \
    549                               .HostMemory("output")       \
    550                               .TypeConstraint<type>("T"), \
    551                           NextIterationOp)
    552 
    553 REGISTER_GPU_HOST_KERNEL(int32);
    554 REGISTER_GPU_HOST_KERNEL(string);
    555 
    556 #undef REGISTER_GPU_HOST_KERNEL
    557 
    558 #ifdef TENSORFLOW_USE_SYCL
    559 #define REGISTER_SYCL_KERNEL(type)                                            \
    560   REGISTER_KERNEL_BUILDER(                                                    \
    561       Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"),    \
    562       NextIterationOp);                                                       \
    563   REGISTER_KERNEL_BUILDER(                                                    \
    564       Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
    565       NextIterationOp)
    566 REGISTER_SYCL_KERNEL(bool);
    567 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
    568 
    569 #undef REGISTER_SYCL_KERNEL
    570 
    571 #define REGISTER_SYCL_HOST_KERNEL(type)                   \
    572   REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
    573                               .Device(DEVICE_SYCL)        \
    574                               .HostMemory("data")         \
    575                               .HostMemory("output")       \
    576                               .TypeConstraint<type>("T"), \
    577                           NextIterationOp);               \
    578   REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
    579                               .Device(DEVICE_SYCL)        \
    580                               .HostMemory("data")         \
    581                               .HostMemory("output")       \
    582                               .TypeConstraint<type>("T"), \
    583                           NextIterationOp)
    584 
    585 REGISTER_SYCL_HOST_KERNEL(int32);
    586 REGISTER_SYCL_HOST_KERNEL(string);
    587 #undef REGISTER_SYCL_HOST_KERNEL
    588 #endif  // TENSORFLOW_USE_SYCL
    589 
    590 // A LoopCond op has one input and one output. The input is a boolean
    591 // scalar representing the taken branches of the "pivot" Switch that
    592 // determines loop termination. As a contract, any high-level front-end
    593 // should always use port '0' of the "pivot" switches for loop exit.
    594 class LoopCondOp : public OpKernel {
    595  public:
    596   explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
    597 
    598   void Compute(OpKernelContext* context) override {
    599     context->set_output(0, context->input(0));
    600   }
    601 
    602   bool IsExpensive() override { return false; }
    603 
    604   ~LoopCondOp() override {}
    605 
    606   TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
    607 };
    608 
    609 REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
    610 REGISTER_KERNEL_BUILDER(Name("LoopCond")
    611                             .Device(DEVICE_GPU)
    612                             .HostMemory("input")
    613                             .HostMemory("output"),
    614                         LoopCondOp);
    615 
    616 #ifdef TENSORFLOW_USE_SYCL
    617 REGISTER_KERNEL_BUILDER(Name("LoopCond")
    618                             .Device(DEVICE_SYCL)
    619                             .HostMemory("input")
    620                             .HostMemory("output"),
    621                         LoopCondOp);
    622 #endif  // TENSORFLOW_USE_SYCL
    623 
    624 // ControlTrigger kernels
    625 REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
    626                         ControlTriggerOp);
    627 
    628 REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
    629                         ControlTriggerOp);
    630 
    631 #ifdef TENSORFLOW_USE_SYCL
    632 REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL),
    633                         ControlTriggerOp);
    634 #endif  // TENSORFLOW_USE_SYCL
    635 
    636 // When called, abort op will abort the current process. This can be used to
    637 // abort remote PSs when needed.
    638 class AbortOp : public OpKernel {
    639  public:
    640   explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) {
    641     OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_));
    642     OP_REQUIRES_OK(
    643         context, context->GetAttr("exit_without_error", &exit_without_error_));
    644   }
    645 
    646   void Compute(OpKernelContext* context) override {
    647     if (!exit_without_error_) {
    648       LOG(FATAL) << "Abort_op intentional failure; " << error_msg_;
    649     } else {
    650       LOG(WARNING) << "Exiting the process: " << error_msg_;
    651       exit(0);
    652     }
    653   }
    654 
    655  private:
    656   string error_msg_;
    657   bool exit_without_error_;
    658 };
    659 
    660 REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);
    661 
    662 }  // namespace tensorflow
    663