1 //===- llvm/ExecutionEngine/Orc/RPCSerialization.h --------------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 10 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H 11 #define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H 12 13 #include "OrcError.h" 14 #include "llvm/Support/thread.h" 15 #include <map> 16 #include <mutex> 17 #include <sstream> 18 19 namespace llvm { 20 namespace orc { 21 namespace rpc { 22 23 template <typename T> 24 class RPCTypeName; 25 26 /// TypeNameSequence is a utility for rendering sequences of types to a string 27 /// by rendering each type, separated by ", ". 28 template <typename... ArgTs> class RPCTypeNameSequence {}; 29 30 /// Render an empty TypeNameSequence to an ostream. 31 template <typename OStream> 32 OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &V) { 33 return OS; 34 } 35 36 /// Render a TypeNameSequence of a single type to an ostream. 37 template <typename OStream, typename ArgT> 38 OStream &operator<<(OStream &OS, const RPCTypeNameSequence<ArgT> &V) { 39 OS << RPCTypeName<ArgT>::getName(); 40 return OS; 41 } 42 43 /// Render a TypeNameSequence of more than one type to an ostream. 44 template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> 45 OStream& 46 operator<<(OStream &OS, const RPCTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { 47 OS << RPCTypeName<ArgT1>::getName() << ", " 48 << RPCTypeNameSequence<ArgT2, ArgTs...>(); 49 return OS; 50 } 51 52 template <> 53 class RPCTypeName<void> { 54 public: 55 static const char* getName() { return "void"; } 56 }; 57 58 template <> 59 class RPCTypeName<int8_t> { 60 public: 61 static const char* getName() { return "int8_t"; } 62 }; 63 64 template <> 65 class RPCTypeName<uint8_t> { 66 public: 67 static const char* getName() { return "uint8_t"; } 68 }; 69 70 template <> 71 class RPCTypeName<int16_t> { 72 public: 73 static const char* getName() { return "int16_t"; } 74 }; 75 76 template <> 77 class RPCTypeName<uint16_t> { 78 public: 79 static const char* getName() { return "uint16_t"; } 80 }; 81 82 template <> 83 class RPCTypeName<int32_t> { 84 public: 85 static const char* getName() { return "int32_t"; } 86 }; 87 88 template <> 89 class RPCTypeName<uint32_t> { 90 public: 91 static const char* getName() { return "uint32_t"; } 92 }; 93 94 template <> 95 class RPCTypeName<int64_t> { 96 public: 97 static const char* getName() { return "int64_t"; } 98 }; 99 100 template <> 101 class RPCTypeName<uint64_t> { 102 public: 103 static const char* getName() { return "uint64_t"; } 104 }; 105 106 template <> 107 class RPCTypeName<bool> { 108 public: 109 static const char* getName() { return "bool"; } 110 }; 111 112 template <> 113 class RPCTypeName<std::string> { 114 public: 115 static const char* getName() { return "std::string"; } 116 }; 117 118 template <> 119 class RPCTypeName<Error> { 120 public: 121 static const char* getName() { return "Error"; } 122 }; 123 124 template <typename T> 125 class RPCTypeName<Expected<T>> { 126 public: 127 static const char* getName() { 128 std::lock_guard<std::mutex> Lock(NameMutex); 129 if (Name.empty()) 130 raw_string_ostream(Name) << "Expected<" 131 << RPCTypeNameSequence<T>() 132 << ">"; 133 return Name.data(); 134 } 135 136 private: 137 static std::mutex NameMutex; 138 static std::string Name; 139 }; 140 141 template <typename T> 142 std::mutex RPCTypeName<Expected<T>>::NameMutex; 143 144 template <typename T> 145 std::string RPCTypeName<Expected<T>>::Name; 146 147 template <typename T1, typename T2> 148 class RPCTypeName<std::pair<T1, T2>> { 149 public: 150 static const char* getName() { 151 std::lock_guard<std::mutex> Lock(NameMutex); 152 if (Name.empty()) 153 raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence<T1, T2>() 154 << ">"; 155 return Name.data(); 156 } 157 private: 158 static std::mutex NameMutex; 159 static std::string Name; 160 }; 161 162 template <typename T1, typename T2> 163 std::mutex RPCTypeName<std::pair<T1, T2>>::NameMutex; 164 template <typename T1, typename T2> 165 std::string RPCTypeName<std::pair<T1, T2>>::Name; 166 167 template <typename... ArgTs> 168 class RPCTypeName<std::tuple<ArgTs...>> { 169 public: 170 static const char* getName() { 171 std::lock_guard<std::mutex> Lock(NameMutex); 172 if (Name.empty()) 173 raw_string_ostream(Name) << "std::tuple<" 174 << RPCTypeNameSequence<ArgTs...>() << ">"; 175 return Name.data(); 176 } 177 private: 178 static std::mutex NameMutex; 179 static std::string Name; 180 }; 181 182 template <typename... ArgTs> 183 std::mutex RPCTypeName<std::tuple<ArgTs...>>::NameMutex; 184 template <typename... ArgTs> 185 std::string RPCTypeName<std::tuple<ArgTs...>>::Name; 186 187 template <typename T> 188 class RPCTypeName<std::vector<T>> { 189 public: 190 static const char*getName() { 191 std::lock_guard<std::mutex> Lock(NameMutex); 192 if (Name.empty()) 193 raw_string_ostream(Name) << "std::vector<" << RPCTypeName<T>::getName() 194 << ">"; 195 return Name.data(); 196 } 197 198 private: 199 static std::mutex NameMutex; 200 static std::string Name; 201 }; 202 203 template <typename T> 204 std::mutex RPCTypeName<std::vector<T>>::NameMutex; 205 template <typename T> 206 std::string RPCTypeName<std::vector<T>>::Name; 207 208 209 /// The SerializationTraits<ChannelT, T> class describes how to serialize and 210 /// deserialize an instance of type T to/from an abstract channel of type 211 /// ChannelT. It also provides a representation of the type's name via the 212 /// getName method. 213 /// 214 /// Specializations of this class should provide the following functions: 215 /// 216 /// @code{.cpp} 217 /// 218 /// static const char* getName(); 219 /// static Error serialize(ChannelT&, const T&); 220 /// static Error deserialize(ChannelT&, T&); 221 /// 222 /// @endcode 223 /// 224 /// The third argument of SerializationTraits is intended to support SFINAE. 225 /// E.g.: 226 /// 227 /// @code{.cpp} 228 /// 229 /// class MyVirtualChannel { ... }; 230 /// 231 /// template <DerivedChannelT> 232 /// class SerializationTraits<DerivedChannelT, bool, 233 /// typename std::enable_if< 234 /// std::is_base_of<VirtChannel, DerivedChannel>::value 235 /// >::type> { 236 /// public: 237 /// static const char* getName() { ... }; 238 /// } 239 /// 240 /// @endcode 241 template <typename ChannelT, typename WireType, 242 typename ConcreteType = WireType, typename = void> 243 class SerializationTraits; 244 245 template <typename ChannelT> 246 class SequenceTraits { 247 public: 248 static Error emitSeparator(ChannelT &C) { return Error::success(); } 249 static Error consumeSeparator(ChannelT &C) { return Error::success(); } 250 }; 251 252 /// Utility class for serializing sequences of values of varying types. 253 /// Specializations of this class contain 'serialize' and 'deserialize' methods 254 /// for the given channel. The ArgTs... list will determine the "over-the-wire" 255 /// types to be serialized. The serialize and deserialize methods take a list 256 /// CArgTs... ("caller arg types") which must be the same length as ArgTs..., 257 /// but may be different types from ArgTs, provided that for each CArgT there 258 /// is a SerializationTraits specialization 259 /// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the 260 /// caller argument to over-the-wire value. 261 template <typename ChannelT, typename... ArgTs> 262 class SequenceSerialization; 263 264 template <typename ChannelT> 265 class SequenceSerialization<ChannelT> { 266 public: 267 static Error serialize(ChannelT &C) { return Error::success(); } 268 static Error deserialize(ChannelT &C) { return Error::success(); } 269 }; 270 271 template <typename ChannelT, typename ArgT> 272 class SequenceSerialization<ChannelT, ArgT> { 273 public: 274 275 template <typename CArgT> 276 static Error serialize(ChannelT &C, CArgT &&CArg) { 277 return SerializationTraits<ChannelT, ArgT, 278 typename std::decay<CArgT>::type>:: 279 serialize(C, std::forward<CArgT>(CArg)); 280 } 281 282 template <typename CArgT> 283 static Error deserialize(ChannelT &C, CArgT &CArg) { 284 return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); 285 } 286 }; 287 288 template <typename ChannelT, typename ArgT, typename... ArgTs> 289 class SequenceSerialization<ChannelT, ArgT, ArgTs...> { 290 public: 291 292 template <typename CArgT, typename... CArgTs> 293 static Error serialize(ChannelT &C, CArgT &&CArg, 294 CArgTs &&... CArgs) { 295 if (auto Err = 296 SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>:: 297 serialize(C, std::forward<CArgT>(CArg))) 298 return Err; 299 if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) 300 return Err; 301 return SequenceSerialization<ChannelT, ArgTs...>:: 302 serialize(C, std::forward<CArgTs>(CArgs)...); 303 } 304 305 template <typename CArgT, typename... CArgTs> 306 static Error deserialize(ChannelT &C, CArgT &CArg, 307 CArgTs &... CArgs) { 308 if (auto Err = 309 SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) 310 return Err; 311 if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) 312 return Err; 313 return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); 314 } 315 }; 316 317 template <typename ChannelT, typename... ArgTs> 318 Error serializeSeq(ChannelT &C, ArgTs &&... Args) { 319 return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>:: 320 serialize(C, std::forward<ArgTs>(Args)...); 321 } 322 323 template <typename ChannelT, typename... ArgTs> 324 Error deserializeSeq(ChannelT &C, ArgTs &... Args) { 325 return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); 326 } 327 328 template <typename ChannelT> 329 class SerializationTraits<ChannelT, Error> { 330 public: 331 332 using WrappedErrorSerializer = 333 std::function<Error(ChannelT &C, const ErrorInfoBase&)>; 334 335 using WrappedErrorDeserializer = 336 std::function<Error(ChannelT &C, Error &Err)>; 337 338 template <typename ErrorInfoT, typename SerializeFtor, 339 typename DeserializeFtor> 340 static void registerErrorType(std::string Name, SerializeFtor Serialize, 341 DeserializeFtor Deserialize) { 342 assert(!Name.empty() && 343 "The empty string is reserved for the Success value"); 344 345 const std::string *KeyName = nullptr; 346 { 347 // We're abusing the stability of std::map here: We take a reference to the 348 // key of the deserializers map to save us from duplicating the string in 349 // the serializer. This should be changed to use a stringpool if we switch 350 // to a map type that may move keys in memory. 351 std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); 352 auto I = 353 Deserializers.insert(Deserializers.begin(), 354 std::make_pair(std::move(Name), 355 std::move(Deserialize))); 356 KeyName = &I->first; 357 } 358 359 { 360 assert(KeyName != nullptr && "No keyname pointer"); 361 std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); 362 // FIXME: Move capture Serialize once we have C++14. 363 Serializers[ErrorInfoT::classID()] = 364 [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { 365 assert(EIB.dynamicClassID() == ErrorInfoT::classID() && 366 "Serializer called for wrong error type"); 367 if (auto Err = serializeSeq(C, *KeyName)) 368 return Err; 369 return Serialize(C, static_cast<const ErrorInfoT &>(EIB)); 370 }; 371 } 372 } 373 374 static Error serialize(ChannelT &C, Error &&Err) { 375 std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); 376 377 if (!Err) 378 return serializeSeq(C, std::string()); 379 380 return handleErrors(std::move(Err), 381 [&C](const ErrorInfoBase &EIB) { 382 auto SI = Serializers.find(EIB.dynamicClassID()); 383 if (SI == Serializers.end()) 384 return serializeAsStringError(C, EIB); 385 return (SI->second)(C, EIB); 386 }); 387 } 388 389 static Error deserialize(ChannelT &C, Error &Err) { 390 std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); 391 392 std::string Key; 393 if (auto Err = deserializeSeq(C, Key)) 394 return Err; 395 396 if (Key.empty()) { 397 ErrorAsOutParameter EAO(&Err); 398 Err = Error::success(); 399 return Error::success(); 400 } 401 402 auto DI = Deserializers.find(Key); 403 assert(DI != Deserializers.end() && "No deserializer for error type"); 404 return (DI->second)(C, Err); 405 } 406 407 private: 408 409 static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { 410 std::string ErrMsg; 411 { 412 raw_string_ostream ErrMsgStream(ErrMsg); 413 EIB.log(ErrMsgStream); 414 } 415 return serialize(C, make_error<StringError>(std::move(ErrMsg), 416 inconvertibleErrorCode())); 417 } 418 419 static std::recursive_mutex SerializersMutex; 420 static std::recursive_mutex DeserializersMutex; 421 static std::map<const void*, WrappedErrorSerializer> Serializers; 422 static std::map<std::string, WrappedErrorDeserializer> Deserializers; 423 }; 424 425 template <typename ChannelT> 426 std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; 427 428 template <typename ChannelT> 429 std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; 430 431 template <typename ChannelT> 432 std::map<const void*, 433 typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> 434 SerializationTraits<ChannelT, Error>::Serializers; 435 436 template <typename ChannelT> 437 std::map<std::string, 438 typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> 439 SerializationTraits<ChannelT, Error>::Deserializers; 440 441 /// Registers a serializer and deserializer for the given error type on the 442 /// given channel type. 443 template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, 444 typename DeserializeFtor> 445 void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, 446 DeserializeFtor &&Deserialize) { 447 SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( 448 std::move(Name), 449 std::forward<SerializeFtor>(Serialize), 450 std::forward<DeserializeFtor>(Deserialize)); 451 } 452 453 /// Registers serialization/deserialization for StringError. 454 template <typename ChannelT> 455 void registerStringError() { 456 static bool AlreadyRegistered = false; 457 if (!AlreadyRegistered) { 458 registerErrorSerialization<ChannelT, StringError>( 459 "StringError", 460 [](ChannelT &C, const StringError &SE) { 461 return serializeSeq(C, SE.getMessage()); 462 }, 463 [](ChannelT &C, Error &Err) -> Error { 464 ErrorAsOutParameter EAO(&Err); 465 std::string Msg; 466 if (auto E2 = deserializeSeq(C, Msg)) 467 return E2; 468 Err = 469 make_error<StringError>(std::move(Msg), 470 orcError( 471 OrcErrorCode::UnknownErrorCodeFromRemote)); 472 return Error::success(); 473 }); 474 AlreadyRegistered = true; 475 } 476 } 477 478 /// SerializationTraits for Expected<T1> from an Expected<T2>. 479 template <typename ChannelT, typename T1, typename T2> 480 class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { 481 public: 482 483 static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { 484 if (ValOrErr) { 485 if (auto Err = serializeSeq(C, true)) 486 return Err; 487 return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); 488 } 489 if (auto Err = serializeSeq(C, false)) 490 return Err; 491 return serializeSeq(C, ValOrErr.takeError()); 492 } 493 494 static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { 495 ExpectedAsOutParameter<T2> EAO(&ValOrErr); 496 bool HasValue; 497 if (auto Err = deserializeSeq(C, HasValue)) 498 return Err; 499 if (HasValue) 500 return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); 501 Error Err = Error::success(); 502 if (auto E2 = deserializeSeq(C, Err)) 503 return E2; 504 ValOrErr = std::move(Err); 505 return Error::success(); 506 } 507 }; 508 509 /// SerializationTraits for Expected<T1> from a T2. 510 template <typename ChannelT, typename T1, typename T2> 511 class SerializationTraits<ChannelT, Expected<T1>, T2> { 512 public: 513 514 static Error serialize(ChannelT &C, T2 &&Val) { 515 return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); 516 } 517 }; 518 519 /// SerializationTraits for Expected<T1> from an Error. 520 template <typename ChannelT, typename T> 521 class SerializationTraits<ChannelT, Expected<T>, Error> { 522 public: 523 524 static Error serialize(ChannelT &C, Error &&Err) { 525 return serializeSeq(C, Expected<T>(std::move(Err))); 526 } 527 }; 528 529 /// SerializationTraits default specialization for std::pair. 530 template <typename ChannelT, typename T1, typename T2> 531 class SerializationTraits<ChannelT, std::pair<T1, T2>> { 532 public: 533 static Error serialize(ChannelT &C, const std::pair<T1, T2> &V) { 534 return serializeSeq(C, V.first, V.second); 535 } 536 537 static Error deserialize(ChannelT &C, std::pair<T1, T2> &V) { 538 return deserializeSeq(C, V.first, V.second); 539 } 540 }; 541 542 /// SerializationTraits default specialization for std::tuple. 543 template <typename ChannelT, typename... ArgTs> 544 class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { 545 public: 546 547 /// RPC channel serialization for std::tuple. 548 static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { 549 return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); 550 } 551 552 /// RPC channel deserialization for std::tuple. 553 static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) { 554 return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); 555 } 556 557 private: 558 // Serialization helper for std::tuple. 559 template <size_t... Is> 560 static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V, 561 llvm::index_sequence<Is...> _) { 562 return serializeSeq(C, std::get<Is>(V)...); 563 } 564 565 // Serialization helper for std::tuple. 566 template <size_t... Is> 567 static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V, 568 llvm::index_sequence<Is...> _) { 569 return deserializeSeq(C, std::get<Is>(V)...); 570 } 571 }; 572 573 /// SerializationTraits default specialization for std::vector. 574 template <typename ChannelT, typename T> 575 class SerializationTraits<ChannelT, std::vector<T>> { 576 public: 577 578 /// Serialize a std::vector<T> from std::vector<T>. 579 static Error serialize(ChannelT &C, const std::vector<T> &V) { 580 if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) 581 return Err; 582 583 for (const auto &E : V) 584 if (auto Err = serializeSeq(C, E)) 585 return Err; 586 587 return Error::success(); 588 } 589 590 /// Deserialize a std::vector<T> to a std::vector<T>. 591 static Error deserialize(ChannelT &C, std::vector<T> &V) { 592 uint64_t Count = 0; 593 if (auto Err = deserializeSeq(C, Count)) 594 return Err; 595 596 V.resize(Count); 597 for (auto &E : V) 598 if (auto Err = deserializeSeq(C, E)) 599 return Err; 600 601 return Error::success(); 602 } 603 }; 604 605 } // end namespace rpc 606 } // end namespace orc 607 } // end namespace llvm 608 609 #endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H 610