Home | History | Annotate | Download | only in framework
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <string>
     17 
     18 #include "tensorflow/core/framework/register_types.h"
     19 #include "tensorflow/core/framework/type_index.h"
     20 #include "tensorflow/core/framework/variant.h"
     21 #include "tensorflow/core/framework/variant_op_registry.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/public/version.h"
     24 
     25 namespace tensorflow {
     26 
     27 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
     28   static std::unordered_set<string>* string_storage =
     29       new std::unordered_set<string>();
     30   return string_storage;
     31 }
     32 
     33 // static
     34 UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
     35   static UnaryVariantOpRegistry* global_unary_variant_op_registry =
     36       new UnaryVariantOpRegistry;
     37   return global_unary_variant_op_registry;
     38 }
     39 
     40 UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
     41     StringPiece type_name) {
     42   auto found = shape_fns.find(type_name);
     43   if (found == shape_fns.end()) return nullptr;
     44   return &found->second;
     45 }
     46 
     47 void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
     48                                              const VariantShapeFn& shape_fn) {
     49   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
     50   VariantShapeFn* existing = GetShapeFn(type_name);
     51   CHECK_EQ(existing, nullptr)
     52       << "Unary VariantShapeFn for type_name: " << type_name
     53       << " already registered";
     54   shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
     55       GetPersistentStringPiece(type_name), shape_fn));
     56 }
     57 
     58 Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
     59   CHECK_EQ(variant_tensor.dtype(), DT_VARIANT);
     60   CHECK_EQ(variant_tensor.dims(), 0);
     61   const Variant& v = variant_tensor.scalar<Variant>()();
     62   UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
     63       UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
     64   if (shape_fn == nullptr) {
     65     return errors::Internal(
     66         "No unary variant shape function found for Variant type_name: ",
     67         v.TypeName());
     68   }
     69   return (*shape_fn)(v, shape);
     70 }
     71 
     72 // Add some basic registrations for use by others, e.g., for testing.
     73 namespace {
     74 template <typename T>
     75 Status ScalarShape(const T&, TensorShape* shape) {
     76   *shape = TensorShape({});
     77   return Status::OK();
     78 }
     79 }  // namespace
     80 
     81 #define REGISTER_VARIANT_SHAPE_TYPE(T) \
     82   REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
     83 
     84 // No encode/shape registered for std::complex<> and Eigen::half
     85 // objects yet.
     86 REGISTER_VARIANT_SHAPE_TYPE(int);
     87 REGISTER_VARIANT_SHAPE_TYPE(float);
     88 REGISTER_VARIANT_SHAPE_TYPE(bool);
     89 REGISTER_VARIANT_SHAPE_TYPE(double);
     90 
     91 #undef REGISTER_VARIANT_SHAPE_TYPE
     92 
     93 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
     94     StringPiece type_name) {
     95   auto found = decode_fns.find(type_name);
     96   if (found == decode_fns.end()) return nullptr;
     97   return &found->second;
     98 }
     99 
    100 void UnaryVariantOpRegistry::RegisterDecodeFn(
    101     const string& type_name, const VariantDecodeFn& decode_fn) {
    102   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
    103   VariantDecodeFn* existing = GetDecodeFn(type_name);
    104   CHECK_EQ(existing, nullptr)
    105       << "Unary VariantDecodeFn for type_name: " << type_name
    106       << " already registered";
    107   decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
    108       GetPersistentStringPiece(type_name), decode_fn));
    109 }
    110 
    111 bool DecodeUnaryVariant(Variant* variant) {
    112   UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
    113       UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
    114   if (decode_fn == nullptr) {
    115     return false;
    116   }
    117   const string type_name = variant->TypeName();
    118   bool decoded = (*decode_fn)(variant);
    119   if (!decoded) return false;
    120   if (variant->TypeName() != type_name) {
    121     LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
    122                << type_name
    123                << " but after decoding was: " << variant->TypeName()
    124                << ".  Treating this as a failure.";
    125     return false;
    126   }
    127   return true;
    128 }
    129 
    130 // Add some basic registrations for use by others, e.g., for testing.
    131 
    132 #define REGISTER_VARIANT_DECODE_TYPE(T) \
    133   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
    134 
    135 // No encode/decode registered for std::complex<> and Eigen::half
    136 // objects yet.
    137 REGISTER_VARIANT_DECODE_TYPE(int);
    138 REGISTER_VARIANT_DECODE_TYPE(float);
    139 REGISTER_VARIANT_DECODE_TYPE(bool);
    140 REGISTER_VARIANT_DECODE_TYPE(double);
    141 
    142 #undef REGISTER_VARIANT_DECODE_TYPE
    143 
    144 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
    145 UnaryVariantOpRegistry::GetDeviceCopyFn(
    146     const VariantDeviceCopyDirection direction, StringPiece type_name) {
    147   auto found = device_copy_fns.find(std::make_pair(direction, type_name));
    148   if (found == device_copy_fns.end()) return nullptr;
    149   return &found->second;
    150 }
    151 
    152 void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
    153     const VariantDeviceCopyDirection direction, const string& type_name,
    154     const AsyncVariantDeviceCopyFn& device_copy_fn) {
    155   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
    156   AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
    157   CHECK_EQ(existing, nullptr)
    158       << "UnaryVariantDeviceCopy for direction: " << direction
    159       << " and type_name: " << type_name << " already registered";
    160   device_copy_fns.insert(
    161       std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
    162                 AsyncVariantDeviceCopyFn>(
    163           std::make_pair(direction, GetPersistentStringPiece(type_name)),
    164           device_copy_fn));
    165 }
    166 
    167 Status VariantDeviceCopy(
    168     const VariantDeviceCopyDirection direction, const Variant& from,
    169     Variant* to,
    170     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
    171   UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
    172       UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
    173                                                         from.TypeName());
    174   if (device_copy_fn == nullptr) {
    175     return errors::Internal(
    176         "No unary variant device copy function found for direction: ",
    177         direction, " and Variant type_name: ", from.TypeName());
    178   }
    179   return (*device_copy_fn)(from, to, copy_fn);
    180 }
    181 
    182 // Special casing UnaryOpFn per op and per device.
    183 UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
    184     VariantUnaryOp op, StringPiece device, StringPiece type_name) {
    185   auto found = unary_op_fns.find({op, device, type_name});
    186   if (found == unary_op_fns.end()) return nullptr;
    187   return &found->second;
    188 }
    189 
    190 void UnaryVariantOpRegistry::RegisterUnaryOpFn(
    191     VariantUnaryOp op, const string& device, const string& type_name,
    192     const VariantUnaryOpFn& unary_op_fn) {
    193   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
    194   VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
    195   CHECK_EQ(existing, nullptr)
    196       << "Unary VariantUnaryOpFn for type_name: " << type_name
    197       << " already registered for device type: " << device;
    198   unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
    199       {op, GetPersistentStringPiece(device),
    200        GetPersistentStringPiece(type_name)},
    201       unary_op_fn));
    202 }
    203 
    204 namespace {
    205 template <typename T>
    206 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
    207                                      T* t_out) {
    208   *t_out = T(0);
    209   return Status::OK();
    210 }
    211 }  // namespace
    212 
    213 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
    214   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
    215                                            DEVICE_CPU, T, TF_STR(T),    \
    216                                            ZerosLikeVariantPrimitiveType<T>);
    217 
    218 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
    219 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
    220 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
    221 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
    222 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
    223 
    224 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
    225 
    226 // Special casing BinaryOpFn per op and per device.
    227 UnaryVariantOpRegistry::VariantBinaryOpFn*
    228 UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
    229                                       StringPiece type_name) {
    230   auto found = binary_op_fns.find({op, device, type_name});
    231   if (found == binary_op_fns.end()) return nullptr;
    232   return &found->second;
    233 }
    234 
    235 void UnaryVariantOpRegistry::RegisterBinaryOpFn(
    236     VariantBinaryOp op, const string& device, const string& type_name,
    237     const VariantBinaryOpFn& add_fn) {
    238   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
    239   VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
    240   CHECK_EQ(existing, nullptr)
    241       << "Unary VariantBinaryOpFn for type_name: " << type_name
    242       << " already registered for device type: " << device;
    243   binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
    244       {op, GetPersistentStringPiece(device),
    245        GetPersistentStringPiece(type_name)},
    246       add_fn));
    247 }
    248 
    249 namespace {
    250 template <typename T>
    251 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
    252                                T* out) {
    253   *out = a + b;
    254   return Status::OK();
    255 }
    256 }  // namespace
    257 
    258 #define REGISTER_VARIANT_ADD_TYPE(T)                                           \
    259   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
    260                                             T, TF_STR(T),                      \
    261                                             AddVariantPrimitiveType<T>);
    262 
    263 // No add registered for std::complex<> or Eigen::half objects yet.
    264 REGISTER_VARIANT_ADD_TYPE(int);
    265 REGISTER_VARIANT_ADD_TYPE(float);
    266 REGISTER_VARIANT_ADD_TYPE(double);
    267 REGISTER_VARIANT_ADD_TYPE(bool);
    268 
    269 #undef REGISTER_VARIANT_ADD_TYPE
    270 
    271 }  // namespace tensorflow
    272