Home | History | Annotate | Download | only in Orc
      1 //===- llvm/ExecutionEngine/Orc/RPCSerialization.h --------------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
     11 #define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
     12 
     13 #include "OrcError.h"
     14 #include "llvm/Support/thread.h"
     15 #include <map>
     16 #include <mutex>
     17 #include <sstream>
     18 
     19 namespace llvm {
     20 namespace orc {
     21 namespace rpc {
     22 
/// Maps a type T to a printable name.
///
/// Only the primary template is declared here; it is never defined. Each
/// supported type supplies a specialization exposing
/// `static const char *getName()`.
template <typename T>
class RPCTypeName;
     25 
/// RPCTypeNameSequence is a utility for rendering sequences of types to a
/// string by rendering each type, separated by ", ". The class itself is an
/// empty tag; the rendering is done by the operator<< overloads that accept it.
template <typename... ArgTs> class RPCTypeNameSequence {};
     29 
     30 /// Render an empty TypeNameSequence to an ostream.
     31 template <typename OStream>
     32 OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &V) {
     33   return OS;
     34 }
     35 
     36 /// Render a TypeNameSequence of a single type to an ostream.
     37 template <typename OStream, typename ArgT>
     38 OStream &operator<<(OStream &OS, const RPCTypeNameSequence<ArgT> &V) {
     39   OS << RPCTypeName<ArgT>::getName();
     40   return OS;
     41 }
     42 
     43 /// Render a TypeNameSequence of more than one type to an ostream.
     44 template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs>
     45 OStream&
     46 operator<<(OStream &OS, const RPCTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) {
     47   OS << RPCTypeName<ArgT1>::getName() << ", "
     48      << RPCTypeNameSequence<ArgT2, ArgTs...>();
     49   return OS;
     50 }
     51 
     52 template <>
     53 class RPCTypeName<void> {
     54 public:
     55   static const char* getName() { return "void"; }
     56 };
     57 
     58 template <>
     59 class RPCTypeName<int8_t> {
     60 public:
     61   static const char* getName() { return "int8_t"; }
     62 };
     63 
     64 template <>
     65 class RPCTypeName<uint8_t> {
     66 public:
     67   static const char* getName() { return "uint8_t"; }
     68 };
     69 
     70 template <>
     71 class RPCTypeName<int16_t> {
     72 public:
     73   static const char* getName() { return "int16_t"; }
     74 };
     75 
     76 template <>
     77 class RPCTypeName<uint16_t> {
     78 public:
     79   static const char* getName() { return "uint16_t"; }
     80 };
     81 
     82 template <>
     83 class RPCTypeName<int32_t> {
     84 public:
     85   static const char* getName() { return "int32_t"; }
     86 };
     87 
     88 template <>
     89 class RPCTypeName<uint32_t> {
     90 public:
     91   static const char* getName() { return "uint32_t"; }
     92 };
     93 
     94 template <>
     95 class RPCTypeName<int64_t> {
     96 public:
     97   static const char* getName() { return "int64_t"; }
     98 };
     99 
    100 template <>
    101 class RPCTypeName<uint64_t> {
    102 public:
    103   static const char* getName() { return "uint64_t"; }
    104 };
    105 
    106 template <>
    107 class RPCTypeName<bool> {
    108 public:
    109   static const char* getName() { return "bool"; }
    110 };
    111 
    112 template <>
    113 class RPCTypeName<std::string> {
    114 public:
    115   static const char* getName() { return "std::string"; }
    116 };
    117 
    118 template <>
    119 class RPCTypeName<Error> {
    120 public:
    121   static const char* getName() { return "Error"; }
    122 };
    123 
    124 template <typename T>
    125 class RPCTypeName<Expected<T>> {
    126 public:
    127   static const char* getName() {
    128     std::lock_guard<std::mutex> Lock(NameMutex);
    129     if (Name.empty())
    130       raw_string_ostream(Name) << "Expected<"
    131                                << RPCTypeNameSequence<T>()
    132                                << ">";
    133     return Name.data();
    134   }
    135 
    136 private:
    137   static std::mutex NameMutex;
    138   static std::string Name;
    139 };
    140 
    141 template <typename T>
    142 std::mutex RPCTypeName<Expected<T>>::NameMutex;
    143 
    144 template <typename T>
    145 std::string RPCTypeName<Expected<T>>::Name;
    146 
    147 template <typename T1, typename T2>
    148 class RPCTypeName<std::pair<T1, T2>> {
    149 public:
    150   static const char* getName() {
    151     std::lock_guard<std::mutex> Lock(NameMutex);
    152     if (Name.empty())
    153       raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence<T1, T2>()
    154                                << ">";
    155     return Name.data();
    156   }
    157 private:
    158   static std::mutex NameMutex;
    159   static std::string Name;
    160 };
    161 
    162 template <typename T1, typename T2>
    163 std::mutex RPCTypeName<std::pair<T1, T2>>::NameMutex;
    164 template <typename T1, typename T2>
    165 std::string RPCTypeName<std::pair<T1, T2>>::Name;
    166 
    167 template <typename... ArgTs>
    168 class RPCTypeName<std::tuple<ArgTs...>> {
    169 public:
    170   static const char* getName() {
    171     std::lock_guard<std::mutex> Lock(NameMutex);
    172     if (Name.empty())
    173       raw_string_ostream(Name) << "std::tuple<"
    174                                << RPCTypeNameSequence<ArgTs...>() << ">";
    175     return Name.data();
    176   }
    177 private:
    178   static std::mutex NameMutex;
    179   static std::string Name;
    180 };
    181 
    182 template <typename... ArgTs>
    183 std::mutex RPCTypeName<std::tuple<ArgTs...>>::NameMutex;
    184 template <typename... ArgTs>
    185 std::string RPCTypeName<std::tuple<ArgTs...>>::Name;
    186 
    187 template <typename T>
    188 class RPCTypeName<std::vector<T>> {
    189 public:
    190   static const char*getName() {
    191     std::lock_guard<std::mutex> Lock(NameMutex);
    192     if (Name.empty())
    193       raw_string_ostream(Name) << "std::vector<" << RPCTypeName<T>::getName()
    194                                << ">";
    195     return Name.data();
    196   }
    197 
    198 private:
    199   static std::mutex NameMutex;
    200   static std::string Name;
    201 };
    202 
    203 template <typename T>
    204 std::mutex RPCTypeName<std::vector<T>>::NameMutex;
    205 template <typename T>
    206 std::string RPCTypeName<std::vector<T>>::Name;
    207 
    208 
/// The SerializationTraits<ChannelT, T> class describes how to serialize and
/// deserialize an instance of type T to/from an abstract channel of type
/// ChannelT. It also provides a representation of the type's name via the
/// getName method.
///
/// Template parameters:
///   - ChannelT:     the channel type values are written to / read from.
///   - WireType:     the over-the-wire type.
///   - ConcreteType: the caller-side type being converted to/from WireType;
///                   defaults to WireType.
///   - (unnamed):    defaulted to void; reserved for SFINAE (see below).
///
/// Specializations of this class should provide the following functions:
///
///   @code{.cpp}
///
///   static const char* getName();
///   static Error serialize(ChannelT&, const T&);
///   static Error deserialize(ChannelT&, T&);
///
///   @endcode
///
/// NOTE(review): the specializations visible in this header define only
/// serialize/deserialize; type names are produced by RPCTypeName instead.
/// Confirm whether the getName() requirement above is still current.
///
/// The fourth (unnamed, defaulted) template argument of SerializationTraits is
/// intended to support SFINAE. Because it follows ConcreteType, a partial
/// specialization that uses it has to spell out ConcreteType as well. E.g.:
///
///   @code{.cpp}
///
///   class MyVirtualChannel { ... };
///
///   template <typename DerivedChannelT>
///   class SerializationTraits<DerivedChannelT, bool, bool,
///         typename std::enable_if<
///           std::is_base_of<MyVirtualChannel, DerivedChannelT>::value
///         >::type> {
///   public:
///     static const char* getName() { ... };
///   };
///
///   @endcode
template <typename ChannelT, typename WireType,
          typename ConcreteType = WireType, typename = void>
class SerializationTraits;
    244 
    245 template <typename ChannelT>
    246 class SequenceTraits {
    247 public:
    248   static Error emitSeparator(ChannelT &C) { return Error::success(); }
    249   static Error consumeSeparator(ChannelT &C) { return Error::success(); }
    250 };
    251 
/// Utility class for serializing sequences of values of varying types.
/// Specializations of this class contain 'serialize' and 'deserialize' methods
/// for the given channel. The ArgTs... list will determine the "over-the-wire"
/// types to be serialized. The serialize and deserialize methods take a list
/// CArgTs... ("caller arg types") which must be the same length as ArgTs...,
/// but may be different types from ArgTs, provided that for each CArgT there
/// is a SerializationTraits specialization
/// SerializationTraits<ChannelT, ArgT, CArgT> with methods that can serialize
/// the caller argument to the over-the-wire value.
template <typename ChannelT, typename... ArgTs>
class SequenceSerialization;
    263 
    264 template <typename ChannelT>
    265 class SequenceSerialization<ChannelT> {
    266 public:
    267   static Error serialize(ChannelT &C) { return Error::success(); }
    268   static Error deserialize(ChannelT &C) { return Error::success(); }
    269 };
    270 
    271 template <typename ChannelT, typename ArgT>
    272 class SequenceSerialization<ChannelT, ArgT> {
    273 public:
    274 
    275   template <typename CArgT>
    276   static Error serialize(ChannelT &C, CArgT &&CArg) {
    277     return SerializationTraits<ChannelT, ArgT,
    278                                typename std::decay<CArgT>::type>::
    279              serialize(C, std::forward<CArgT>(CArg));
    280   }
    281 
    282   template <typename CArgT>
    283   static Error deserialize(ChannelT &C, CArgT &CArg) {
    284     return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg);
    285   }
    286 };
    287 
    288 template <typename ChannelT, typename ArgT, typename... ArgTs>
    289 class SequenceSerialization<ChannelT, ArgT, ArgTs...> {
    290 public:
    291 
    292   template <typename CArgT, typename... CArgTs>
    293   static Error serialize(ChannelT &C, CArgT &&CArg,
    294                          CArgTs &&... CArgs) {
    295     if (auto Err =
    296         SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>::
    297           serialize(C, std::forward<CArgT>(CArg)))
    298       return Err;
    299     if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C))
    300       return Err;
    301     return SequenceSerialization<ChannelT, ArgTs...>::
    302              serialize(C, std::forward<CArgTs>(CArgs)...);
    303   }
    304 
    305   template <typename CArgT, typename... CArgTs>
    306   static Error deserialize(ChannelT &C, CArgT &CArg,
    307                            CArgTs &... CArgs) {
    308     if (auto Err =
    309         SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg))
    310       return Err;
    311     if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C))
    312       return Err;
    313     return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...);
    314   }
    315 };
    316 
    317 template <typename ChannelT, typename... ArgTs>
    318 Error serializeSeq(ChannelT &C, ArgTs &&... Args) {
    319   return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>::
    320            serialize(C, std::forward<ArgTs>(Args)...);
    321 }
    322 
    323 template <typename ChannelT, typename... ArgTs>
    324 Error deserializeSeq(ChannelT &C, ArgTs &... Args) {
    325   return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...);
    326 }
    327 
    328 template <typename ChannelT>
    329 class SerializationTraits<ChannelT, Error> {
    330 public:
    331 
    332   using WrappedErrorSerializer =
    333     std::function<Error(ChannelT &C, const ErrorInfoBase&)>;
    334 
    335   using WrappedErrorDeserializer =
    336     std::function<Error(ChannelT &C, Error &Err)>;
    337 
    338   template <typename ErrorInfoT, typename SerializeFtor,
    339             typename DeserializeFtor>
    340   static void registerErrorType(std::string Name, SerializeFtor Serialize,
    341                                 DeserializeFtor Deserialize) {
    342     assert(!Name.empty() &&
    343            "The empty string is reserved for the Success value");
    344 
    345     const std::string *KeyName = nullptr;
    346     {
    347       // We're abusing the stability of std::map here: We take a reference to the
    348       // key of the deserializers map to save us from duplicating the string in
    349       // the serializer. This should be changed to use a stringpool if we switch
    350       // to a map type that may move keys in memory.
    351       std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex);
    352       auto I =
    353         Deserializers.insert(Deserializers.begin(),
    354                              std::make_pair(std::move(Name),
    355                                             std::move(Deserialize)));
    356       KeyName = &I->first;
    357     }
    358 
    359     {
    360       assert(KeyName != nullptr && "No keyname pointer");
    361       std::lock_guard<std::recursive_mutex> Lock(SerializersMutex);
    362       // FIXME: Move capture Serialize once we have C++14.
    363       Serializers[ErrorInfoT::classID()] =
    364           [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error {
    365         assert(EIB.dynamicClassID() == ErrorInfoT::classID() &&
    366                "Serializer called for wrong error type");
    367         if (auto Err = serializeSeq(C, *KeyName))
    368           return Err;
    369         return Serialize(C, static_cast<const ErrorInfoT &>(EIB));
    370       };
    371     }
    372   }
    373 
    374   static Error serialize(ChannelT &C, Error &&Err) {
    375     std::lock_guard<std::recursive_mutex> Lock(SerializersMutex);
    376 
    377     if (!Err)
    378       return serializeSeq(C, std::string());
    379 
    380     return handleErrors(std::move(Err),
    381                         [&C](const ErrorInfoBase &EIB) {
    382                           auto SI = Serializers.find(EIB.dynamicClassID());
    383                           if (SI == Serializers.end())
    384                             return serializeAsStringError(C, EIB);
    385                           return (SI->second)(C, EIB);
    386                         });
    387   }
    388 
    389   static Error deserialize(ChannelT &C, Error &Err) {
    390     std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex);
    391 
    392     std::string Key;
    393     if (auto Err = deserializeSeq(C, Key))
    394       return Err;
    395 
    396     if (Key.empty()) {
    397       ErrorAsOutParameter EAO(&Err);
    398       Err = Error::success();
    399       return Error::success();
    400     }
    401 
    402     auto DI = Deserializers.find(Key);
    403     assert(DI != Deserializers.end() && "No deserializer for error type");
    404     return (DI->second)(C, Err);
    405   }
    406 
    407 private:
    408 
    409   static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) {
    410     std::string ErrMsg;
    411     {
    412       raw_string_ostream ErrMsgStream(ErrMsg);
    413       EIB.log(ErrMsgStream);
    414     }
    415     return serialize(C, make_error<StringError>(std::move(ErrMsg),
    416                                                 inconvertibleErrorCode()));
    417   }
    418 
    419   static std::recursive_mutex SerializersMutex;
    420   static std::recursive_mutex DeserializersMutex;
    421   static std::map<const void*, WrappedErrorSerializer> Serializers;
    422   static std::map<std::string, WrappedErrorDeserializer> Deserializers;
    423 };
    424 
    425 template <typename ChannelT>
    426 std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex;
    427 
    428 template <typename ChannelT>
    429 std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex;
    430 
    431 template <typename ChannelT>
    432 std::map<const void*,
    433          typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer>
    434 SerializationTraits<ChannelT, Error>::Serializers;
    435 
    436 template <typename ChannelT>
    437 std::map<std::string,
    438          typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer>
    439 SerializationTraits<ChannelT, Error>::Deserializers;
    440 
    441 /// Registers a serializer and deserializer for the given error type on the
    442 /// given channel type.
    443 template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor,
    444           typename DeserializeFtor>
    445 void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize,
    446                                 DeserializeFtor &&Deserialize) {
    447   SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>(
    448     std::move(Name),
    449     std::forward<SerializeFtor>(Serialize),
    450     std::forward<DeserializeFtor>(Deserialize));
    451 }
    452 
    453 /// Registers serialization/deserialization for StringError.
    454 template <typename ChannelT>
    455 void registerStringError() {
    456   static bool AlreadyRegistered = false;
    457   if (!AlreadyRegistered) {
    458     registerErrorSerialization<ChannelT, StringError>(
    459       "StringError",
    460       [](ChannelT &C, const StringError &SE) {
    461         return serializeSeq(C, SE.getMessage());
    462       },
    463       [](ChannelT &C, Error &Err) -> Error {
    464         ErrorAsOutParameter EAO(&Err);
    465         std::string Msg;
    466         if (auto E2 = deserializeSeq(C, Msg))
    467           return E2;
    468         Err =
    469           make_error<StringError>(std::move(Msg),
    470                                   orcError(
    471                                     OrcErrorCode::UnknownErrorCodeFromRemote));
    472         return Error::success();
    473       });
    474     AlreadyRegistered = true;
    475   }
    476 }
    477 
    478 /// SerializationTraits for Expected<T1> from an Expected<T2>.
    479 template <typename ChannelT, typename T1, typename T2>
    480 class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> {
    481 public:
    482 
    483   static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) {
    484     if (ValOrErr) {
    485       if (auto Err = serializeSeq(C, true))
    486         return Err;
    487       return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr);
    488     }
    489     if (auto Err = serializeSeq(C, false))
    490       return Err;
    491     return serializeSeq(C, ValOrErr.takeError());
    492   }
    493 
    494   static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) {
    495     ExpectedAsOutParameter<T2> EAO(&ValOrErr);
    496     bool HasValue;
    497     if (auto Err = deserializeSeq(C, HasValue))
    498       return Err;
    499     if (HasValue)
    500       return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr);
    501     Error Err = Error::success();
    502     if (auto E2 = deserializeSeq(C, Err))
    503       return E2;
    504     ValOrErr = std::move(Err);
    505     return Error::success();
    506   }
    507 };
    508 
    509 /// SerializationTraits for Expected<T1> from a T2.
    510 template <typename ChannelT, typename T1, typename T2>
    511 class SerializationTraits<ChannelT, Expected<T1>, T2> {
    512 public:
    513 
    514   static Error serialize(ChannelT &C, T2 &&Val) {
    515     return serializeSeq(C, Expected<T2>(std::forward<T2>(Val)));
    516   }
    517 };
    518 
    519 /// SerializationTraits for Expected<T1> from an Error.
    520 template <typename ChannelT, typename T>
    521 class SerializationTraits<ChannelT, Expected<T>, Error> {
    522 public:
    523 
    524   static Error serialize(ChannelT &C, Error &&Err) {
    525     return serializeSeq(C, Expected<T>(std::move(Err)));
    526   }
    527 };
    528 
    529 /// SerializationTraits default specialization for std::pair.
    530 template <typename ChannelT, typename T1, typename T2>
    531 class SerializationTraits<ChannelT, std::pair<T1, T2>> {
    532 public:
    533   static Error serialize(ChannelT &C, const std::pair<T1, T2> &V) {
    534     return serializeSeq(C, V.first, V.second);
    535   }
    536 
    537   static Error deserialize(ChannelT &C, std::pair<T1, T2> &V) {
    538     return deserializeSeq(C, V.first, V.second);
    539   }
    540 };
    541 
    542 /// SerializationTraits default specialization for std::tuple.
    543 template <typename ChannelT, typename... ArgTs>
    544 class SerializationTraits<ChannelT, std::tuple<ArgTs...>> {
    545 public:
    546 
    547   /// RPC channel serialization for std::tuple.
    548   static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) {
    549     return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
    550   }
    551 
    552   /// RPC channel deserialization for std::tuple.
    553   static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) {
    554     return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>());
    555   }
    556 
    557 private:
    558   // Serialization helper for std::tuple.
    559   template <size_t... Is>
    560   static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V,
    561                                     llvm::index_sequence<Is...> _) {
    562     return serializeSeq(C, std::get<Is>(V)...);
    563   }
    564 
    565   // Serialization helper for std::tuple.
    566   template <size_t... Is>
    567   static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V,
    568                                       llvm::index_sequence<Is...> _) {
    569     return deserializeSeq(C, std::get<Is>(V)...);
    570   }
    571 };
    572 
    573 /// SerializationTraits default specialization for std::vector.
    574 template <typename ChannelT, typename T>
    575 class SerializationTraits<ChannelT, std::vector<T>> {
    576 public:
    577 
    578   /// Serialize a std::vector<T> from std::vector<T>.
    579   static Error serialize(ChannelT &C, const std::vector<T> &V) {
    580     if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size())))
    581       return Err;
    582 
    583     for (const auto &E : V)
    584       if (auto Err = serializeSeq(C, E))
    585         return Err;
    586 
    587     return Error::success();
    588   }
    589 
    590   /// Deserialize a std::vector<T> to a std::vector<T>.
    591   static Error deserialize(ChannelT &C, std::vector<T> &V) {
    592     uint64_t Count = 0;
    593     if (auto Err = deserializeSeq(C, Count))
    594       return Err;
    595 
    596     V.resize(Count);
    597     for (auto &E : V)
    598       if (auto Err = deserializeSeq(C, E))
    599         return Err;
    600 
    601     return Error::success();
    602   }
    603 };
    604 
    605 } // end namespace rpc
    606 } // end namespace orc
    607 } // end namespace llvm
    608 
    609 #endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H
    610