/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace tensorflow {

class OpKernelContext;
// A global UnaryVariantOpRegistry is used to hold callback functions
// for different variant types.  Its callbacks are used by ShapeOp,
// RankOp, SizeOp, variant decoding, device copies, and the registered
// unary and binary variant ops.

enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2,
};

enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1,
};

enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3,
};

class UnaryVariantOpRegistry {
 public:
  typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
  typedef std::function<bool(Variant*)> VariantDecodeFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
      VariantUnaryOpFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
                               Variant*)>
      VariantBinaryOpFn;

  // An AsyncTensorDeviceCopyFn is the function passed to a user-provided
  // DeviceCopyFn callback as its third argument ("copier").
  //
  // Expected inputs:
  //   from: A Tensor on the host (if performing cpu->gpu copy), or
  //         device (if performing gpu->cpu or gpu->gpu copy).
  //   to: An empty/uninitialized tensor.  It will be updated upon
  //       successful return of the function with the correct dtype and shape.
  //       However, the copied data will not be available until the compute
  //       stream has been synchronized.
  //
  // Returns:
  //   The status upon memory allocation / initialization of the
  //   "to" tensor, and enqueue of the copy onto the compute stream.
  //   Any failure of the copy itself will update the underlying
  //   stream status and propagate through the runtime independent
  //   of the caller.
  typedef std::function<Status(const Tensor& from, Tensor* to)>
      AsyncTensorDeviceCopyFn;

  // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
  // expected to be passed to the registration macro
  // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
  typedef std::function<Status(const Variant& from, Variant* to,
                               AsyncTensorDeviceCopyFn copy_fn)>
      AsyncVariantDeviceCopyFn;

  // Add a shape lookup function to the registry.
  void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);

  // Returns nullptr if no shape function was found for the given TypeName.
  VariantShapeFn* GetShapeFn(StringPiece type_name);

  // Add a decode function to the registry.
  void RegisterDecodeFn(const string& type_name,
                        const VariantDecodeFn& decode_fn);

  // Returns nullptr if no decode function was found for the given TypeName.
  VariantDecodeFn* GetDecodeFn(StringPiece type_name);

  // Add a device copy function (for the given copy direction) to the registry.
  void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
                            const string& type_name,
                            const AsyncVariantDeviceCopyFn& device_copy_fn);

  // Returns nullptr if no copy function was found for the given
  // TypeName and direction.
  AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
      const VariantDeviceCopyDirection direction, StringPiece type_name);

  // Add a unary op function to the registry.
  void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
                         const string& type_name,
                         const VariantUnaryOpFn& unary_op_fn);

  // Returns nullptr if no unary op function was found for the given
  // op, device, and TypeName.
  VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
                                 StringPiece type_name);

  // Add a binary op function to the registry.
  void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
                          const string& type_name,
                          const VariantBinaryOpFn& add_fn);

  // Returns nullptr if no binary op function was found for the given
  // op, device and TypeName.
  VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
                                   StringPiece type_name);

  // Get a pointer to a global UnaryVariantOpRegistry object
  static UnaryVariantOpRegistry* Global();

  // Get a pointer to a global persistent string storage object.
  // ISO/IEC C++ working draft N4296 clarifies that insertion into an
  // std::unordered_set does not invalidate memory locations of
  // *values* inside the set (though it may invalidate existing
  // iterators).  In other words, one may safely point a StringPiece to
  // a value in the set without that StringPiece being invalidated by
  // future insertions.
  static std::unordered_set<string>* PersistentStringStorage();

 private:
  std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
  std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
      decode_fns;

  // Map std::pair<Direction, type_name> to function.
  struct PairHash {
    template <typename Direction>
    std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
      return ret;
    }
    StringPieceHasher sp_hasher_;
  };

  std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
                     AsyncVariantDeviceCopyFn, PairHash>
      device_copy_fns;

  // Map std::tuple<Op, device, type_name> to function.

  // Using std::tuple directly as the map key breaks by falling victim to
  // "too perfect forwarding"; see
  // https://stackoverflow.com/questions/44475317/variadic-template-issue
  // and the references therein.  FuncTuple is used as the key type instead.
  template <typename Op>
  struct FuncTuple {
    FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
        : op_type_(op), device_(dev), typename_(tname) {}
    Op op_type_;
    StringPiece device_, typename_;
  };
  // friend declaration for operator==
  // needed for clang
  template <typename Op>
  friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
  struct TupleHash {
    template <typename Op>
    std::size_t operator()(
        const std::tuple<Op, StringPiece, StringPiece>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
      ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
      return ret;
    }

    template <typename Op>
    std::size_t operator()(const FuncTuple<Op>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(x.op_type_);
      ret = Hash64Combine(ret, sp_hasher_(x.device_));
      ret = Hash64Combine(ret, sp_hasher_(x.typename_));
      return ret;
    }
    StringPieceHasher sp_hasher_;
  };
  std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
      unary_op_fns;
  std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
      binary_op_fns;

  // Find or insert a string into a persistent string storage
  // container; return the StringPiece pointing to the permanent string
  // location.
  static StringPiece GetPersistentStringPiece(const string& str) {
    const auto string_storage = PersistentStringStorage();
    auto found = string_storage->find(str);
    if (found == string_storage->end()) {
      auto inserted = string_storage->insert(str);
      return StringPiece(*inserted.first);
    } else {
      return StringPiece(*found);
    }
  }
};
template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
                       const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
  return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
         (lhs.typename_ == rhs.typename_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
// function, or if it's a serialized Variant that cannot be decoded.
//
// REQUIRES:
//   variant_tensor.dtype() == DT_VARIANT
//   variant_tensor.dims() == 0
//
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
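
// For example (an illustrative sketch; MyVariantType is a hypothetical
// user-defined type with a registered shape function):
//
//   Tensor variant_tensor(DT_VARIANT, TensorShape({}));
//   variant_tensor.scalar<Variant>()() = MyVariantType();
//   TensorShape shape;
//   TF_RETURN_IF_ERROR(GetUnaryVariantShape(variant_tensor, &shape));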

// Decodes the Variant whose data_type has a registered decode
// function.  Returns false if the Variant does not have a registered
// decode function, or if the decoding function fails.
//
// REQUIRES:
//   variant is not null.
//
bool DecodeUnaryVariant(Variant* variant);
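
// For example (an illustrative sketch; assumes `variant` currently holds a
// serialized VariantTensorDataProto whose TypeName has a registered decode
// function):
//
//   if (!DecodeUnaryVariant(&variant)) {
//     return errors::Internal("Could not decode variant with type_name: ",
//                             variant.TypeName());
//   }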

// Copies a variant between CPU<->GPU, or between GPU<->GPU.
// The variant 'from' must have a registered DeviceCopyFn for the
// given direction.  The returned variant 'to' will have
// (some subset of its) tensors stored on the destination device,
// according to the registered DeviceCopyFn for the given direction.
// Returns an Internal error if the Variant does not have a registered
// DeviceCopyFn function for the given direction, or if initiating the
// copy fails.
//
// REQUIRES:
//   'to' is not null.
//
Status VariantDeviceCopy(
    const VariantDeviceCopyDirection direction, const Variant& from,
    Variant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
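
// For example (an illustrative sketch; a real copier would be backed by a
// device compute stream rather than returning immediately):
//
//   Variant copied;
//   Status s = VariantDeviceCopy(
//       VariantDeviceCopyDirection::HOST_TO_DEVICE, original, &copied,
//       [](const Tensor& from, Tensor* to) -> Status {
//         // Allocate `to` on the device and enqueue the copy here.
//         return Status::OK();
//       });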

// Sets *v_out = unary_op(v).  The variant v must have a registered
// UnaryOp function for the given Device.  Returns an Internal error
// if v does not have a registered unary_op function for this device, or if
// UnaryOp fails.
//
// REQUIRES:
//   v_out is not null.
//
template <typename Device>
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
                      Variant* v_out) {
  const string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
  if (unary_op_fn == nullptr) {
    return errors::Internal(
        "No unary variant unary_op function found for unary variant op enum: ",
        op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
  }
  return (*unary_op_fn)(ctx, v, v_out);
}
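
// For example (an illustrative sketch; CPUDevice is the usual alias for
// Eigen::ThreadPoolDevice, and the input's type is assumed to have a
// ZEROS_LIKE function registered for CPU):
//
//   Variant v_out;
//   OP_REQUIRES_OK(ctx, UnaryOpVariant<CPUDevice>(
//                           ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));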

// Sets *out = binary_op(a, b).  The variants a and b must be the same type
// and have a registered binary_op function for the given Device.  Returns an
// Internal error if a and b are not the same type_name, if a does not have
// a registered op function for this device, or if BinaryOp fails.
//
// REQUIRES:
//   out is not null.
//
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
                        const Variant& a, const Variant& b, Variant* out) {
  if (a.TypeName() != b.TypeName()) {
    return errors::Internal(
        "BinaryOpVariants: Variants a and b have different "
        "type names: '",
        a.TypeName(), "' vs. '", b.TypeName(), "'");
  }
  const string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
  if (binary_op_fn == nullptr) {
    return errors::Internal(
        "No unary variant binary_op function found for binary variant op "
        "enum: ",
        op, " Variant type_name: '", a.TypeName(), "' for device type: ",
        device);
  }
  return (*binary_op_fn)(ctx, a, b, out);
}
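
// For example (an illustrative sketch mirroring the unary example above;
// assumes an ADD function is registered for the type on CPU):
//
//   Variant out;
//   OP_REQUIRES_OK(ctx, BinaryOpVariants<CPUDevice>(
//                           ctx, ADD_VARIANT_BINARY_OP, a, b, &out));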

namespace variant_op_registry_fn_registration {

template <typename T>
class UnaryVariantShapeRegistration {
 public:
  typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;

  UnaryVariantShapeRegistration(const string& type_name,
                                const LocalVariantShapeFn& shape_fn) {
    UnaryVariantOpRegistry::Global()->RegisterShapeFn(
        type_name,
        [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
          const T* t = v.get<T>();
          if (t == nullptr) {
            return errors::Internal(
                "VariantShapeFn: Could not access object, type_name: ",
                type_name);
          }
          return shape_fn(*t, s);
        });
  }
};

template <typename T>
class UnaryVariantDecodeRegistration {
 public:
  UnaryVariantDecodeRegistration(const string& type_name) {
    // The Variant is passed by pointer because it should be
    // mutable: get below may Decode the variant, which
    // is a self-mutating behavior.  The variant is not modified in
    // any other way.
    UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
        type_name, [type_name](Variant* v) -> bool {
          DCHECK_NE(v, nullptr);
          VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
          if (t == nullptr) {
            return false;
          }
          Variant decoded = T();
          VariantTensorData data(*t);
          if (!decoded.Decode(data)) {
            return false;
          }
          *v = std::move(decoded);
          return true;
        });
  }
};

template <typename T>
class UnaryVariantDeviceCopyRegistration {
 public:
  typedef std::function<Status(const T& t, T* t_out,
                               UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
      LocalVariantDeviceCopyFn;
  UnaryVariantDeviceCopyRegistration(
      const VariantDeviceCopyDirection direction, const string& type_name,
      const LocalVariantDeviceCopyFn& device_copy_fn) {
    UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
        direction, type_name,
        [type_name, device_copy_fn](
            const Variant& from, Variant* to,
            UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
                device_copy_tensor_fn) -> Status {
          DCHECK_NE(to, nullptr);
          *to = T();
          if (from.get<T>() == nullptr) {
            return errors::Internal(
                "VariantDeviceCopyFn: Could not access object, type_name: ",
                type_name);
          }
          const T& t = *from.get<T>();
          T* t_out = to->get<T>();
          return device_copy_fn(t, t_out, device_copy_tensor_fn);
        });
  }
};

template <typename T>
class UnaryVariantUnaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
      LocalVariantUnaryOpFn;

 public:
  UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
                                  const string& type_name,
                                  const LocalVariantUnaryOpFn& unary_op_fn) {
    UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
        op, device, type_name,
        [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
                                 Variant* v_out) -> Status {
          DCHECK_NE(v_out, nullptr);
          *v_out = T();
          if (v.get<T>() == nullptr) {
            return errors::Internal(
                "VariantUnaryOpFn: Could not access object, type_name: ",
                type_name);
          }
          const T& t = *v.get<T>();
          T* t_out = v_out->get<T>();
          return unary_op_fn(ctx, t, t_out);
        });
  }
};

template <typename T>
class UnaryVariantBinaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
                               T* out)>
      LocalVariantBinaryOpFn;

 public:
  UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
                                   const string& type_name,
                                   const LocalVariantBinaryOpFn& binary_op_fn) {
    UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
        op, device, type_name,
        [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
                                  const Variant& b, Variant* out) -> Status {
          DCHECK_NE(out, nullptr);
          *out = T();
          if (a.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'a', type_name: ",
                type_name);
          }
          if (b.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'b', type_name: ",
                type_name);
          }
          const T& t_a = *a.get<T>();
          const T& t_b = *b.get<T>();
          T* t_out = out->get<T>();
          return binary_op_fn(ctx, t_a, t_b, t_out);
        });
  }
};

}  // namespace variant_op_registry_fn_registration

// Register a unary shape variant function with the signature:
//    Status ShapeFn(const T& t, TensorShape* s);
// to Variants having TypeName type_name.
#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function)    \
  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
                                                    shape_function)

#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
                                                          shape_function)    \
  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)

#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name,          \
                                                   shape_function)             \
  static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
      register_unary_variant_op_shape_registration_fn_##ctr(type_name,         \
                                                            shape_function)
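
// For example (an illustrative sketch; MyVariantType and its shape function
// are hypothetical):
//
//   static Status MyVariantTypeShape(const MyVariantType& t, TensorShape* s) {
//     *s = TensorShape({});  // Report a scalar shape.
//     return Status::OK();
//   }
//   REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyVariantType, "MyVariantType",
//                                         MyVariantTypeShape);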

// Register a unary decode variant function for the given type.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)        \
  static variant_op_registry_fn_registration::UnaryVariantDecodeRegistration< \
      T>                                                                      \
      register_unary_variant_op_decoder_fn_##ctr(type_name)
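
// For example (an illustrative sketch; MyVariantType must provide the
// TypeName/Encode/Decode methods expected by the Variant framework):
//
//   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyVariantType, "MyVariantType");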

// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(       \
    T, direction, type_name, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
      __COUNTER__, T, direction, type_name, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_name, device_copy_fn)                         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_name, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(             \
    ctr, T, direction, type_name, device_copy_fn)                              \
  static variant_op_registry_fn_registration::                                 \
      UnaryVariantDeviceCopyRegistration<T>                                    \
          register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \
                                                         device_copy_fn)
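
// For example (an illustrative sketch; MyVariantType, with a single Tensor
// member `tensor`, is hypothetical):
//
//   static Status MyVariantTypeCopyToDevice(
//       const MyVariantType& t, MyVariantType* t_out,
//       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
//     return copier(t.tensor, &t_out->tensor);
//   }
//   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
//       MyVariantType, VariantDeviceCopyDirection::HOST_TO_DEVICE,
//       "MyVariantType", MyVariantTypeCopyToDevice);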

// Register a unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeName type_name, for device string device,
// for VariantUnaryOp enum op.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
                                                 unary_op_function)        \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                    \
      __COUNTER__, op, device, T, type_name, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                  \
    ctr, op, device, T, type_name, unary_op_function)                          \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
                                                unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                         \
    ctr, op, device, T, type_name, unary_op_function)                          \
  static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
      T>                                                                       \
      register_unary_variant_op_unary_op_fn_##ctr(op, device, type_name,       \
                                                  unary_op_function)
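
// For example (an illustrative sketch; registers a hypothetical zeros_like
// function for MyVariantType on CPU):
//
//   static Status MyVariantTypeZerosLike(OpKernelContext* ctx,
//                                        const MyVariantType& t,
//                                        MyVariantType* t_out) {
//     // Fill *t_out with zeros derived from t here.
//     return Status::OK();
//   }
//   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
//       ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MyVariantType,
//       "MyVariantType", MyVariantTypeZerosLike);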

// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeName type_name, for device string device,
// for VariantBinaryOp enum op.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
                                                  binary_op_function)       \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(                    \
      __COUNTER__, op, device, T, type_name, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_name, binary_op_function)         \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_name, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                       \
    ctr, op, device, T, type_name, binary_op_function)                        \
  static variant_op_registry_fn_registration::                                \
      UnaryVariantBinaryOpRegistration<T>                                     \
          register_unary_variant_op_binary_op_fn_##ctr(op, device, type_name, \
                                                       binary_op_function)
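
// For example (an illustrative sketch; registers a hypothetical add function
// for MyVariantType on CPU):
//
//   static Status MyVariantTypeAdd(OpKernelContext* ctx,
//                                  const MyVariantType& a,
//                                  const MyVariantType& b,
//                                  MyVariantType* out) {
//     // Compute *out = a + b here.
//     return Status::OK();
//   }
//   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
//       ADD_VARIANT_BINARY_OP, DEVICE_CPU, MyVariantType,
//       "MyVariantType", MyVariantTypeAdd);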

}  // end namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_