1 //===------- RPCUTils.h - Utilities for building RPC APIs -------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Utilities to support construction of simple RPC APIs. 11 // 12 // The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ 13 // programmers, high performance, low memory overhead, and efficient use of the 14 // communications channel. 15 // 16 //===----------------------------------------------------------------------===// 17 18 #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H 19 #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H 20 21 #include <map> 22 #include <thread> 23 #include <vector> 24 25 #include "llvm/ADT/STLExtras.h" 26 #include "llvm/ExecutionEngine/Orc/OrcError.h" 27 #include "llvm/ExecutionEngine/Orc/RPCSerialization.h" 28 29 #include <future> 30 31 namespace llvm { 32 namespace orc { 33 namespace rpc { 34 35 /// Base class of all fatal RPC errors (those that necessarily result in the 36 /// termination of the RPC session). 37 class RPCFatalError : public ErrorInfo<RPCFatalError> { 38 public: 39 static char ID; 40 }; 41 42 /// RPCConnectionClosed is returned from RPC operations if the RPC connection 43 /// has already been closed due to either an error or graceful disconnection. 44 class ConnectionClosed : public ErrorInfo<ConnectionClosed> { 45 public: 46 static char ID; 47 std::error_code convertToErrorCode() const override; 48 void log(raw_ostream &OS) const override; 49 }; 50 51 /// BadFunctionCall is returned from handleOne when the remote makes a call with 52 /// an unrecognized function id. 53 /// 54 /// This error is fatal because Orc RPC needs to know how to parse a function 55 /// call to know where the next call starts, and if it doesn't recognize the 56 /// function id it cannot parse the call. 57 template <typename FnIdT, typename SeqNoT> 58 class BadFunctionCall 59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { 60 public: 61 static char ID; 62 63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) 64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} 65 66 std::error_code convertToErrorCode() const override { 67 return orcError(OrcErrorCode::UnexpectedRPCCall); 68 } 69 70 void log(raw_ostream &OS) const override { 71 OS << "Call to invalid RPC function id '" << FnId << "' with " 72 "sequence number " << SeqNo; 73 } 74 75 private: 76 FnIdT FnId; 77 SeqNoT SeqNo; 78 }; 79 80 template <typename FnIdT, typename SeqNoT> 81 char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; 82 83 /// InvalidSequenceNumberForResponse is returned from handleOne when a response 84 /// call arrives with a sequence number that doesn't correspond to any in-flight 85 /// function call. 86 /// 87 /// This error is fatal because Orc RPC needs to know how to parse the rest of 88 /// the response call to know where the next call starts, and if it doesn't have 89 /// a result parser for this sequence number it can't do that. 90 template <typename SeqNoT> 91 class InvalidSequenceNumberForResponse 92 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> { 93 public: 94 static char ID; 95 96 InvalidSequenceNumberForResponse(SeqNoT SeqNo) 97 : SeqNo(std::move(SeqNo)) {} 98 99 std::error_code convertToErrorCode() const override { 100 return orcError(OrcErrorCode::UnexpectedRPCCall); 101 }; 102 103 void log(raw_ostream &OS) const override { 104 OS << "Response has unknown sequence number " << SeqNo; 105 } 106 private: 107 SeqNoT SeqNo; 108 }; 109 110 template <typename SeqNoT> 111 char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; 112 113 /// This non-fatal error will be passed to asynchronous result handlers in place 114 /// of a result if the connection goes down before a result returns, or if the 115 /// function to be called cannot be negotiated with the remote. 116 class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { 117 public: 118 static char ID; 119 120 std::error_code convertToErrorCode() const override; 121 void log(raw_ostream &OS) const override; 122 }; 123 124 /// This error is returned if the remote does not have a handler installed for 125 /// the given RPC function. 126 class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { 127 public: 128 static char ID; 129 130 CouldNotNegotiate(std::string Signature); 131 std::error_code convertToErrorCode() const override; 132 void log(raw_ostream &OS) const override; 133 const std::string &getSignature() const { return Signature; } 134 private: 135 std::string Signature; 136 }; 137 138 template <typename DerivedFunc, typename FnT> class Function; 139 140 // RPC Function class. 141 // DerivedFunc should be a user defined class with a static 'getName()' method 142 // returning a const char* representing the function's name. 143 template <typename DerivedFunc, typename RetT, typename... ArgTs> 144 class Function<DerivedFunc, RetT(ArgTs...)> { 145 public: 146 /// User defined function type. 147 using Type = RetT(ArgTs...); 148 149 /// Return type. 150 using ReturnType = RetT; 151 152 /// Returns the full function prototype as a string. 153 static const char *getPrototype() { 154 std::lock_guard<std::mutex> Lock(NameMutex); 155 if (Name.empty()) 156 raw_string_ostream(Name) 157 << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName() 158 << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")"; 159 return Name.data(); 160 } 161 162 private: 163 static std::mutex NameMutex; 164 static std::string Name; 165 }; 166 167 template <typename DerivedFunc, typename RetT, typename... ArgTs> 168 std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex; 169 170 template <typename DerivedFunc, typename RetT, typename... ArgTs> 171 std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; 172 173 /// Allocates RPC function ids during autonegotiation. 174 /// Specializations of this class must provide four members: 175 /// 176 /// static T getInvalidId(): 177 /// Should return a reserved id that will be used to represent missing 178 /// functions during autonegotiation. 179 /// 180 /// static T getResponseId(): 181 /// Should return a reserved id that will be used to send function responses 182 /// (return values). 183 /// 184 /// static T getNegotiateId(): 185 /// Should return a reserved id for the negotiate function, which will be used 186 /// to negotiate ids for user defined functions. 187 /// 188 /// template <typename Func> T allocate(): 189 /// Allocate a unique id for function Func. 190 template <typename T, typename = void> class RPCFunctionIdAllocator; 191 192 /// This specialization of RPCFunctionIdAllocator provides a default 193 /// implementation for integral types. 194 template <typename T> 195 class RPCFunctionIdAllocator< 196 T, typename std::enable_if<std::is_integral<T>::value>::type> { 197 public: 198 static T getInvalidId() { return T(0); } 199 static T getResponseId() { return T(1); } 200 static T getNegotiateId() { return T(2); } 201 202 template <typename Func> T allocate() { return NextId++; } 203 204 private: 205 T NextId = 3; 206 }; 207 208 namespace detail { 209 210 // FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation 211 // supports classes without default constructors. 212 #ifdef _MSC_VER 213 214 namespace msvc_hacks { 215 216 // Work around MSVC's future implementation's use of default constructors: 217 // A default constructed value in the promise will be overwritten when the 218 // real error is set - so the default constructed Error has to be checked 219 // already. 220 class MSVCPError : public Error { 221 public: 222 MSVCPError() { (void)!!*this; } 223 224 MSVCPError(MSVCPError &&Other) : Error(std::move(Other)) {} 225 226 MSVCPError &operator=(MSVCPError Other) { 227 Error::operator=(std::move(Other)); 228 return *this; 229 } 230 231 MSVCPError(Error Err) : Error(std::move(Err)) {} 232 }; 233 234 // Work around MSVC's future implementation, similar to MSVCPError. 235 template <typename T> class MSVCPExpected : public Expected<T> { 236 public: 237 MSVCPExpected() 238 : Expected<T>(make_error<StringError>("", inconvertibleErrorCode())) { 239 consumeError(this->takeError()); 240 } 241 242 MSVCPExpected(MSVCPExpected &&Other) : Expected<T>(std::move(Other)) {} 243 244 MSVCPExpected &operator=(MSVCPExpected &&Other) { 245 Expected<T>::operator=(std::move(Other)); 246 return *this; 247 } 248 249 MSVCPExpected(Error Err) : Expected<T>(std::move(Err)) {} 250 251 template <typename OtherT> 252 MSVCPExpected( 253 OtherT &&Val, 254 typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * = 255 nullptr) 256 : Expected<T>(std::move(Val)) {} 257 258 template <class OtherT> 259 MSVCPExpected( 260 Expected<OtherT> &&Other, 261 typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * = 262 nullptr) 263 : Expected<T>(std::move(Other)) {} 264 265 template <class OtherT> 266 explicit MSVCPExpected( 267 Expected<OtherT> &&Other, 268 typename std::enable_if<!std::is_convertible<OtherT, T>::value>::type * = 269 nullptr) 270 : Expected<T>(std::move(Other)) {} 271 }; 272 273 } // end namespace msvc_hacks 274 275 #endif // _MSC_VER 276 277 /// Provides a typedef for a tuple containing the decayed argument types. 278 template <typename T> class FunctionArgsTuple; 279 280 template <typename RetT, typename... ArgTs> 281 class FunctionArgsTuple<RetT(ArgTs...)> { 282 public: 283 using Type = std::tuple<typename std::decay< 284 typename std::remove_reference<ArgTs>::type>::type...>; 285 }; 286 287 // ResultTraits provides typedefs and utilities specific to the return type 288 // of functions. 289 template <typename RetT> class ResultTraits { 290 public: 291 // The return type wrapped in llvm::Expected. 292 using ErrorReturnType = Expected<RetT>; 293 294 #ifdef _MSC_VER 295 // The ErrorReturnType wrapped in a std::promise. 296 using ReturnPromiseType = std::promise<msvc_hacks::MSVCPExpected<RetT>>; 297 298 // The ErrorReturnType wrapped in a std::future. 299 using ReturnFutureType = std::future<msvc_hacks::MSVCPExpected<RetT>>; 300 #else 301 // The ErrorReturnType wrapped in a std::promise. 302 using ReturnPromiseType = std::promise<ErrorReturnType>; 303 304 // The ErrorReturnType wrapped in a std::future. 305 using ReturnFutureType = std::future<ErrorReturnType>; 306 #endif 307 308 // Create a 'blank' value of the ErrorReturnType, ready and safe to 309 // overwrite. 310 static ErrorReturnType createBlankErrorReturnValue() { 311 return ErrorReturnType(RetT()); 312 } 313 314 // Consume an abandoned ErrorReturnType. 315 static void consumeAbandoned(ErrorReturnType RetOrErr) { 316 consumeError(RetOrErr.takeError()); 317 } 318 }; 319 320 // ResultTraits specialization for void functions. 321 template <> class ResultTraits<void> { 322 public: 323 // For void functions, ErrorReturnType is llvm::Error. 324 using ErrorReturnType = Error; 325 326 #ifdef _MSC_VER 327 // The ErrorReturnType wrapped in a std::promise. 328 using ReturnPromiseType = std::promise<msvc_hacks::MSVCPError>; 329 330 // The ErrorReturnType wrapped in a std::future. 331 using ReturnFutureType = std::future<msvc_hacks::MSVCPError>; 332 #else 333 // The ErrorReturnType wrapped in a std::promise. 334 using ReturnPromiseType = std::promise<ErrorReturnType>; 335 336 // The ErrorReturnType wrapped in a std::future. 337 using ReturnFutureType = std::future<ErrorReturnType>; 338 #endif 339 340 // Create a 'blank' value of the ErrorReturnType, ready and safe to 341 // overwrite. 342 static ErrorReturnType createBlankErrorReturnValue() { 343 return ErrorReturnType::success(); 344 } 345 346 // Consume an abandoned ErrorReturnType. 347 static void consumeAbandoned(ErrorReturnType Err) { 348 consumeError(std::move(Err)); 349 } 350 }; 351 352 // ResultTraits<Error> is equivalent to ResultTraits<void>. This allows 353 // handlers for void RPC functions to return either void (in which case they 354 // implicitly succeed) or Error (in which case their error return is 355 // propagated). See usage in HandlerTraits::runHandlerHelper. 356 template <> class ResultTraits<Error> : public ResultTraits<void> {}; 357 358 // ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows 359 // handlers for RPC functions returning a T to return either a T (in which 360 // case they implicitly succeed) or Expected<T> (in which case their error 361 // return is propagated). See usage in HandlerTraits::runHandlerHelper. 362 template <typename RetT> 363 class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; 364 365 // Determines whether an RPC function's defined error return type supports 366 // error return value. 367 template <typename T> 368 class SupportsErrorReturn { 369 public: 370 static const bool value = false; 371 }; 372 373 template <> 374 class SupportsErrorReturn<Error> { 375 public: 376 static const bool value = true; 377 }; 378 379 template <typename T> 380 class SupportsErrorReturn<Expected<T>> { 381 public: 382 static const bool value = true; 383 }; 384 385 // RespondHelper packages return values based on whether or not the declared 386 // RPC function return type supports error returns. 387 template <bool FuncSupportsErrorReturn> 388 class RespondHelper; 389 390 // RespondHelper specialization for functions that support error returns. 391 template <> 392 class RespondHelper<true> { 393 public: 394 395 // Send Expected<T>. 396 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 397 typename FunctionIdT, typename SequenceNumberT> 398 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 399 SequenceNumberT SeqNo, 400 Expected<HandlerRetT> ResultOrErr) { 401 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) 402 return ResultOrErr.takeError(); 403 404 // Open the response message. 405 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 406 return Err; 407 408 // Serialize the result. 409 if (auto Err = 410 SerializationTraits<ChannelT, WireRetT, 411 Expected<HandlerRetT>>::serialize( 412 C, std::move(ResultOrErr))) 413 return Err; 414 415 // Close the response message. 416 return C.endSendMessage(); 417 } 418 419 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 420 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 421 SequenceNumberT SeqNo, Error Err) { 422 if (Err && Err.isA<RPCFatalError>()) 423 return Err; 424 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 425 return Err2; 426 if (auto Err2 = serializeSeq(C, std::move(Err))) 427 return Err2; 428 return C.endSendMessage(); 429 } 430 431 }; 432 433 // RespondHelper specialization for functions that do not support error returns. 434 template <> 435 class RespondHelper<false> { 436 public: 437 438 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 439 typename FunctionIdT, typename SequenceNumberT> 440 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 441 SequenceNumberT SeqNo, 442 Expected<HandlerRetT> ResultOrErr) { 443 if (auto Err = ResultOrErr.takeError()) 444 return Err; 445 446 // Open the response message. 447 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 448 return Err; 449 450 // Serialize the result. 451 if (auto Err = 452 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( 453 C, *ResultOrErr)) 454 return Err; 455 456 // Close the response message. 457 return C.endSendMessage(); 458 } 459 460 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 461 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 462 SequenceNumberT SeqNo, Error Err) { 463 if (Err) 464 return Err; 465 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 466 return Err2; 467 return C.endSendMessage(); 468 } 469 470 }; 471 472 473 // Send a response of the given wire return type (WireRetT) over the 474 // channel, with the given sequence number. 475 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 476 typename FunctionIdT, typename SequenceNumberT> 477 Error respond(ChannelT &C, const FunctionIdT &ResponseId, 478 SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { 479 return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: 480 template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr)); 481 } 482 483 // Send an empty response message on the given channel to indicate that 484 // the handler ran. 485 template <typename WireRetT, typename ChannelT, typename FunctionIdT, 486 typename SequenceNumberT> 487 Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, 488 Error Err) { 489 return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: 490 sendResult(C, ResponseId, SeqNo, std::move(Err)); 491 } 492 493 // Converts a given type to the equivalent error return type. 494 template <typename T> class WrappedHandlerReturn { 495 public: 496 using Type = Expected<T>; 497 }; 498 499 template <typename T> class WrappedHandlerReturn<Expected<T>> { 500 public: 501 using Type = Expected<T>; 502 }; 503 504 template <> class WrappedHandlerReturn<void> { 505 public: 506 using Type = Error; 507 }; 508 509 template <> class WrappedHandlerReturn<Error> { 510 public: 511 using Type = Error; 512 }; 513 514 template <> class WrappedHandlerReturn<ErrorSuccess> { 515 public: 516 using Type = Error; 517 }; 518 519 // Traits class that strips the response function from the list of handler 520 // arguments. 521 template <typename FnT> class AsyncHandlerTraits; 522 523 template <typename ResultT, typename... ArgTs> 524 class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> { 525 public: 526 using Type = Error(ArgTs...); 527 using ResultType = Expected<ResultT>; 528 }; 529 530 template <typename... ArgTs> 531 class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { 532 public: 533 using Type = Error(ArgTs...); 534 using ResultType = Error; 535 }; 536 537 template <typename... ArgTs> 538 class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> { 539 public: 540 using Type = Error(ArgTs...); 541 using ResultType = Error; 542 }; 543 544 template <typename... ArgTs> 545 class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> { 546 public: 547 using Type = Error(ArgTs...); 548 using ResultType = Error; 549 }; 550 551 template <typename ResponseHandlerT, typename... ArgTs> 552 class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> : 553 public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type, 554 ArgTs...)> {}; 555 556 // This template class provides utilities related to RPC function handlers. 557 // The base case applies to non-function types (the template class is 558 // specialized for function types) and inherits from the appropriate 559 // speciilization for the given non-function type's call operator. 560 template <typename HandlerT> 561 class HandlerTraits : public HandlerTraits<decltype( 562 &std::remove_reference<HandlerT>::type::operator())> { 563 }; 564 565 // Traits for handlers with a given function type. 566 template <typename RetT, typename... ArgTs> 567 class HandlerTraits<RetT(ArgTs...)> { 568 public: 569 // Function type of the handler. 570 using Type = RetT(ArgTs...); 571 572 // Return type of the handler. 573 using ReturnType = RetT; 574 575 // Call the given handler with the given arguments. 576 template <typename HandlerT, typename... TArgTs> 577 static typename WrappedHandlerReturn<RetT>::Type 578 unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { 579 return unpackAndRunHelper(Handler, Args, 580 llvm::index_sequence_for<TArgTs...>()); 581 } 582 583 // Call the given handler with the given arguments. 584 template <typename HandlerT, typename ResponderT, typename... TArgTs> 585 static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, 586 std::tuple<TArgTs...> &Args) { 587 return unpackAndRunAsyncHelper(Handler, Responder, Args, 588 llvm::index_sequence_for<TArgTs...>()); 589 } 590 591 // Call the given handler with the given arguments. 592 template <typename HandlerT> 593 static typename std::enable_if< 594 std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, 595 Error>::type 596 run(HandlerT &Handler, ArgTs &&... Args) { 597 Handler(std::move(Args)...); 598 return Error::success(); 599 } 600 601 template <typename HandlerT, typename... TArgTs> 602 static typename std::enable_if< 603 !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, 604 typename HandlerTraits<HandlerT>::ReturnType>::type 605 run(HandlerT &Handler, TArgTs... Args) { 606 return Handler(std::move(Args)...); 607 } 608 609 // Serialize arguments to the channel. 610 template <typename ChannelT, typename... CArgTs> 611 static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { 612 return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); 613 } 614 615 // Deserialize arguments from the channel. 616 template <typename ChannelT, typename... CArgTs> 617 static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { 618 return deserializeArgsHelper(C, Args, 619 llvm::index_sequence_for<CArgTs...>()); 620 } 621 622 private: 623 template <typename ChannelT, typename... CArgTs, size_t... Indexes> 624 static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, 625 llvm::index_sequence<Indexes...> _) { 626 return SequenceSerialization<ChannelT, ArgTs...>::deserialize( 627 C, std::get<Indexes>(Args)...); 628 } 629 630 template <typename HandlerT, typename ArgTuple, size_t... Indexes> 631 static typename WrappedHandlerReturn< 632 typename HandlerTraits<HandlerT>::ReturnType>::Type 633 unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, 634 llvm::index_sequence<Indexes...>) { 635 return run(Handler, std::move(std::get<Indexes>(Args))...); 636 } 637 638 639 template <typename HandlerT, typename ResponderT, typename ArgTuple, 640 size_t... Indexes> 641 static typename WrappedHandlerReturn< 642 typename HandlerTraits<HandlerT>::ReturnType>::Type 643 unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, 644 ArgTuple &Args, 645 llvm::index_sequence<Indexes...>) { 646 return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); 647 } 648 }; 649 650 // Handler traits for free functions. 651 template <typename RetT, typename... ArgTs> 652 class HandlerTraits<RetT(*)(ArgTs...)> 653 : public HandlerTraits<RetT(ArgTs...)> {}; 654 655 // Handler traits for class methods (especially call operators for lambdas). 656 template <typename Class, typename RetT, typename... ArgTs> 657 class HandlerTraits<RetT (Class::*)(ArgTs...)> 658 : public HandlerTraits<RetT(ArgTs...)> {}; 659 660 // Handler traits for const class methods (especially call operators for 661 // lambdas). 662 template <typename Class, typename RetT, typename... ArgTs> 663 class HandlerTraits<RetT (Class::*)(ArgTs...) const> 664 : public HandlerTraits<RetT(ArgTs...)> {}; 665 666 // Utility to peel the Expected wrapper off a response handler error type. 667 template <typename HandlerT> class ResponseHandlerArg; 668 669 template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { 670 public: 671 using ArgType = Expected<ArgT>; 672 using UnwrappedArgType = ArgT; 673 }; 674 675 template <typename ArgT> 676 class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { 677 public: 678 using ArgType = Expected<ArgT>; 679 using UnwrappedArgType = ArgT; 680 }; 681 682 template <> class ResponseHandlerArg<Error(Error)> { 683 public: 684 using ArgType = Error; 685 }; 686 687 template <> class ResponseHandlerArg<ErrorSuccess(Error)> { 688 public: 689 using ArgType = Error; 690 }; 691 692 // ResponseHandler represents a handler for a not-yet-received function call 693 // result. 694 template <typename ChannelT> class ResponseHandler { 695 public: 696 virtual ~ResponseHandler() {} 697 698 // Reads the function result off the wire and acts on it. The meaning of 699 // "act" will depend on how this method is implemented in any given 700 // ResponseHandler subclass but could, for example, mean running a 701 // user-specified handler or setting a promise value. 702 virtual Error handleResponse(ChannelT &C) = 0; 703 704 // Abandons this outstanding result. 705 virtual void abandon() = 0; 706 707 // Create an error instance representing an abandoned response. 708 static Error createAbandonedResponseError() { 709 return make_error<ResponseAbandoned>(); 710 } 711 }; 712 713 // ResponseHandler subclass for RPC functions with non-void returns. 714 template <typename ChannelT, typename FuncRetT, typename HandlerT> 715 class ResponseHandlerImpl : public ResponseHandler<ChannelT> { 716 public: 717 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 718 719 // Handle the result by deserializing it from the channel then passing it 720 // to the user defined handler. 721 Error handleResponse(ChannelT &C) override { 722 using UnwrappedArgType = typename ResponseHandlerArg< 723 typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; 724 UnwrappedArgType Result; 725 if (auto Err = 726 SerializationTraits<ChannelT, FuncRetT, 727 UnwrappedArgType>::deserialize(C, Result)) 728 return Err; 729 if (auto Err = C.endReceiveMessage()) 730 return Err; 731 return Handler(std::move(Result)); 732 } 733 734 // Abandon this response by calling the handler with an 'abandoned response' 735 // error. 736 void abandon() override { 737 if (auto Err = Handler(this->createAbandonedResponseError())) { 738 // Handlers should not fail when passed an abandoned response error. 739 report_fatal_error(std::move(Err)); 740 } 741 } 742 743 private: 744 HandlerT Handler; 745 }; 746 747 // ResponseHandler subclass for RPC functions with void returns. 748 template <typename ChannelT, typename HandlerT> 749 class ResponseHandlerImpl<ChannelT, void, HandlerT> 750 : public ResponseHandler<ChannelT> { 751 public: 752 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 753 754 // Handle the result (no actual value, just a notification that the function 755 // has completed on the remote end) by calling the user-defined handler with 756 // Error::success(). 757 Error handleResponse(ChannelT &C) override { 758 if (auto Err = C.endReceiveMessage()) 759 return Err; 760 return Handler(Error::success()); 761 } 762 763 // Abandon this response by calling the handler with an 'abandoned response' 764 // error. 765 void abandon() override { 766 if (auto Err = Handler(this->createAbandonedResponseError())) { 767 // Handlers should not fail when passed an abandoned response error. 768 report_fatal_error(std::move(Err)); 769 } 770 } 771 772 private: 773 HandlerT Handler; 774 }; 775 776 template <typename ChannelT, typename FuncRetT, typename HandlerT> 777 class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> 778 : public ResponseHandler<ChannelT> { 779 public: 780 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 781 782 // Handle the result by deserializing it from the channel then passing it 783 // to the user defined handler. 784 Error handleResponse(ChannelT &C) override { 785 using HandlerArgType = typename ResponseHandlerArg< 786 typename HandlerTraits<HandlerT>::Type>::ArgType; 787 HandlerArgType Result((typename HandlerArgType::value_type())); 788 789 if (auto Err = 790 SerializationTraits<ChannelT, Expected<FuncRetT>, 791 HandlerArgType>::deserialize(C, Result)) 792 return Err; 793 if (auto Err = C.endReceiveMessage()) 794 return Err; 795 return Handler(std::move(Result)); 796 } 797 798 // Abandon this response by calling the handler with an 'abandoned response' 799 // error. 800 void abandon() override { 801 if (auto Err = Handler(this->createAbandonedResponseError())) { 802 // Handlers should not fail when passed an abandoned response error. 803 report_fatal_error(std::move(Err)); 804 } 805 } 806 807 private: 808 HandlerT Handler; 809 }; 810 811 template <typename ChannelT, typename HandlerT> 812 class ResponseHandlerImpl<ChannelT, Error, HandlerT> 813 : public ResponseHandler<ChannelT> { 814 public: 815 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 816 817 // Handle the result by deserializing it from the channel then passing it 818 // to the user defined handler. 819 Error handleResponse(ChannelT &C) override { 820 Error Result = Error::success(); 821 if (auto Err = 822 SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result)) 823 return Err; 824 if (auto Err = C.endReceiveMessage()) 825 return Err; 826 return Handler(std::move(Result)); 827 } 828 829 // Abandon this response by calling the handler with an 'abandoned response' 830 // error. 831 void abandon() override { 832 if (auto Err = Handler(this->createAbandonedResponseError())) { 833 // Handlers should not fail when passed an abandoned response error. 834 report_fatal_error(std::move(Err)); 835 } 836 } 837 838 private: 839 HandlerT Handler; 840 }; 841 842 // Create a ResponseHandler from a given user handler. 843 template <typename ChannelT, typename FuncRetT, typename HandlerT> 844 std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { 845 return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>( 846 std::move(H)); 847 } 848 849 // Helper for wrapping member functions up as functors. This is useful for 850 // installing methods as result handlers. 851 template <typename ClassT, typename RetT, typename... ArgTs> 852 class MemberFnWrapper { 853 public: 854 using MethodT = RetT (ClassT::*)(ArgTs...); 855 MemberFnWrapper(ClassT &Instance, MethodT Method) 856 : Instance(Instance), Method(Method) {} 857 RetT operator()(ArgTs &&... Args) { 858 return (Instance.*Method)(std::move(Args)...); 859 } 860 861 private: 862 ClassT &Instance; 863 MethodT Method; 864 }; 865 866 // Helper that provides a Functor for deserializing arguments. 867 template <typename... ArgTs> class ReadArgs { 868 public: 869 Error operator()() { return Error::success(); } 870 }; 871 872 template <typename ArgT, typename... ArgTs> 873 class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { 874 public: 875 ReadArgs(ArgT &Arg, ArgTs &... Args) 876 : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} 877 878 Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { 879 this->Arg = std::move(ArgVal); 880 return ReadArgs<ArgTs...>::operator()(ArgVals...); 881 } 882 883 private: 884 ArgT &Arg; 885 }; 886 887 // Manage sequence numbers. 888 template <typename SequenceNumberT> class SequenceNumberManager { 889 public: 890 // Reset, making all sequence numbers available. 891 void reset() { 892 std::lock_guard<std::mutex> Lock(SeqNoLock); 893 NextSequenceNumber = 0; 894 FreeSequenceNumbers.clear(); 895 } 896 897 // Get the next available sequence number. Will re-use numbers that have 898 // been released. 899 SequenceNumberT getSequenceNumber() { 900 std::lock_guard<std::mutex> Lock(SeqNoLock); 901 if (FreeSequenceNumbers.empty()) 902 return NextSequenceNumber++; 903 auto SequenceNumber = FreeSequenceNumbers.back(); 904 FreeSequenceNumbers.pop_back(); 905 return SequenceNumber; 906 } 907 908 // Release a sequence number, making it available for re-use. 909 void releaseSequenceNumber(SequenceNumberT SequenceNumber) { 910 std::lock_guard<std::mutex> Lock(SeqNoLock); 911 FreeSequenceNumbers.push_back(SequenceNumber); 912 } 913 914 private: 915 std::mutex SeqNoLock; 916 SequenceNumberT NextSequenceNumber = 0; 917 std::vector<SequenceNumberT> FreeSequenceNumbers; 918 }; 919 920 // Checks that predicate P holds for each corresponding pair of type arguments 921 // from T1 and T2 tuple. 922 template <template <class, class> class P, typename T1Tuple, typename T2Tuple> 923 class RPCArgTypeCheckHelper; 924 925 template <template <class, class> class P> 926 class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { 927 public: 928 static const bool value = true; 929 }; 930 931 template <template <class, class> class P, typename T, typename... Ts, 932 typename U, typename... Us> 933 class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { 934 public: 935 static const bool value = 936 P<T, U>::value && 937 RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; 938 }; 939 940 template <template <class, class> class P, typename T1Sig, typename T2Sig> 941 class RPCArgTypeCheck { 942 public: 943 using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type; 944 using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type; 945 946 static_assert(std::tuple_size<T1Tuple>::value >= 947 std::tuple_size<T2Tuple>::value, 948 "Too many arguments to RPC call"); 949 static_assert(std::tuple_size<T1Tuple>::value <= 950 std::tuple_size<T2Tuple>::value, 951 "Too few arguments to RPC call"); 952 953 static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; 954 }; 955 956 template <typename ChannelT, typename WireT, typename ConcreteT> 957 class CanSerialize { 958 private: 959 using S = SerializationTraits<ChannelT, WireT, ConcreteT>; 960 961 template <typename T> 962 static std::true_type 963 check(typename std::enable_if< 964 std::is_same<decltype(T::serialize(std::declval<ChannelT &>(), 965 std::declval<const ConcreteT &>())), 966 Error>::value, 967 void *>::type); 968 969 template <typename> static std::false_type check(...); 970 971 public: 972 static const bool value = decltype(check<S>(0))::value; 973 }; 974 975 template <typename ChannelT, typename WireT, typename ConcreteT> 976 class CanDeserialize { 977 private: 978 using S = SerializationTraits<ChannelT, WireT, ConcreteT>; 979 980 template <typename T> 981 static std::true_type 982 check(typename std::enable_if< 983 std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(), 984 std::declval<ConcreteT &>())), 985 Error>::value, 986 void *>::type); 987 988 template <typename> static std::false_type check(...); 989 990 public: 991 static const bool value = decltype(check<S>(0))::value; 992 }; 993 994 /// Contains primitive utilities for defining, calling and handling calls to 995 /// remote procedures. ChannelT is a bidirectional stream conforming to the 996 /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure 997 /// identifier type that must be serializable on ChannelT, and SequenceNumberT 998 /// is an integral type that will be used to number in-flight function calls. 999 /// 1000 /// These utilities support the construction of very primitive RPC utilities. 1001 /// Their intent is to ensure correct serialization and deserialization of 1002 /// procedure arguments, and to keep the client and server's view of the API in 1003 /// sync. 1004 template <typename ImplT, typename ChannelT, typename FunctionIdT, 1005 typename SequenceNumberT> 1006 class RPCEndpointBase { 1007 protected: 1008 class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { 1009 public: 1010 static const char *getName() { return "__orc_rpc$invalid"; } 1011 }; 1012 1013 class OrcRPCResponse : public Function<OrcRPCResponse, void()> { 1014 public: 1015 static const char *getName() { return "__orc_rpc$response"; } 1016 }; 1017 1018 class OrcRPCNegotiate 1019 : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> { 1020 public: 1021 static const char *getName() { return "__orc_rpc$negotiate"; } 1022 }; 1023 1024 // Helper predicate for testing for the presence of SerializeTraits 1025 // serializers. 1026 template <typename WireT, typename ConcreteT> 1027 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { 1028 public: 1029 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; 1030 1031 static_assert(value, "Missing serializer for argument (Can't serialize the " 1032 "first template type argument of CanSerializeCheck " 1033 "from the second)"); 1034 }; 1035 1036 // Helper predicate for testing for the presence of SerializeTraits 1037 // deserializers. 1038 template <typename WireT, typename ConcreteT> 1039 class CanDeserializeCheck 1040 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { 1041 public: 1042 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; 1043 1044 static_assert(value, "Missing deserializer for argument (Can't deserialize " 1045 "the second template type argument of " 1046 "CanDeserializeCheck from the first)"); 1047 }; 1048 1049 public: 1050 /// Construct an RPC instance on a channel. 1051 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) 1052 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { 1053 // Hold ResponseId in a special variable, since we expect Response to be 1054 // called relatively frequently, and want to avoid the map lookup. 1055 ResponseId = FnIdAllocator.getResponseId(); 1056 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; 1057 1058 // Register the negotiate function id and handler. 1059 auto NegotiateId = FnIdAllocator.getNegotiateId(); 1060 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; 1061 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( 1062 [this](const std::string &Name) { return handleNegotiate(Name); }); 1063 } 1064 1065 1066 /// Negotiate a function id for Func with the other end of the channel. 1067 template <typename Func> Error negotiateFunction(bool Retry = false) { 1068 return getRemoteFunctionId<Func>(true, Retry).takeError(); 1069 } 1070 1071 /// Append a call Func, does not call send on the channel. 1072 /// The first argument specifies a user-defined handler to be run when the 1073 /// function returns. The handler should take an Expected<Func::ReturnType>, 1074 /// or an Error (if Func::ReturnType is void). The handler will be called 1075 /// with an error if the return value is abandoned due to a channel error. 1076 template <typename Func, typename HandlerT, typename... ArgTs> 1077 Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) { 1078 1079 static_assert( 1080 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, 1081 void(ArgTs...)>::value, 1082 ""); 1083 1084 // Look up the function ID. 1085 FunctionIdT FnId; 1086 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) 1087 FnId = *FnIdOrErr; 1088 else { 1089 // Negotiation failed. Notify the handler then return the negotiate-failed 1090 // error. 1091 cantFail(Handler(make_error<ResponseAbandoned>())); 1092 return FnIdOrErr.takeError(); 1093 } 1094 1095 SequenceNumberT SeqNo; // initialized in locked scope below. 1096 { 1097 // Lock the pending responses map and sequence number manager. 1098 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1099 1100 // Allocate a sequence number. 1101 SeqNo = SequenceNumberMgr.getSequenceNumber(); 1102 assert(!PendingResponses.count(SeqNo) && 1103 "Sequence number already allocated"); 1104 1105 // Install the user handler. 1106 PendingResponses[SeqNo] = 1107 detail::createResponseHandler<ChannelT, typename Func::ReturnType>( 1108 std::move(Handler)); 1109 } 1110 1111 // Open the function call message. 1112 if (auto Err = C.startSendMessage(FnId, SeqNo)) { 1113 abandonPendingResponses(); 1114 return Err; 1115 } 1116 1117 // Serialize the call arguments. 1118 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( 1119 C, Args...)) { 1120 abandonPendingResponses(); 1121 return Err; 1122 } 1123 1124 // Close the function call messagee. 1125 if (auto Err = C.endSendMessage()) { 1126 abandonPendingResponses(); 1127 return Err; 1128 } 1129 1130 return Error::success(); 1131 } 1132 1133 Error sendAppendedCalls() { return C.send(); }; 1134 1135 template <typename Func, typename HandlerT, typename... ArgTs> 1136 Error callAsync(HandlerT Handler, const ArgTs &... Args) { 1137 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) 1138 return Err; 1139 return C.send(); 1140 } 1141 1142 /// Handle one incoming call. 1143 Error handleOne() { 1144 FunctionIdT FnId; 1145 SequenceNumberT SeqNo; 1146 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { 1147 abandonPendingResponses(); 1148 return Err; 1149 } 1150 if (FnId == ResponseId) 1151 return handleResponse(SeqNo); 1152 auto I = Handlers.find(FnId); 1153 if (I != Handlers.end()) 1154 return I->second(C, SeqNo); 1155 1156 // else: No handler found. Report error to client? 1157 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, 1158 SeqNo); 1159 } 1160 1161 /// Helper for handling setter procedures - this method returns a functor that 1162 /// sets the variables referred to by Args... to values deserialized from the 1163 /// channel. 1164 /// E.g. 1165 /// 1166 /// typedef Function<0, bool, int> Func1; 1167 /// 1168 /// ... 1169 /// bool B; 1170 /// int I; 1171 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) 1172 /// /* Handle Args */ ; 1173 /// 1174 template <typename... ArgTs> 1175 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { 1176 return detail::ReadArgs<ArgTs...>(Args...); 1177 } 1178 1179 /// Abandon all outstanding result handlers. 1180 /// 1181 /// This will call all currently registered result handlers to receive an 1182 /// "abandoned" error as their argument. This is used internally by the RPC 1183 /// in error situations, but can also be called directly by clients who are 1184 /// disconnecting from the remote and don't or can't expect responses to their 1185 /// outstanding calls. (Especially for outstanding blocking calls, calling 1186 /// this function may be necessary to avoid dead threads). 1187 void abandonPendingResponses() { 1188 // Lock the pending responses map and sequence number manager. 1189 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1190 1191 for (auto &KV : PendingResponses) 1192 KV.second->abandon(); 1193 PendingResponses.clear(); 1194 SequenceNumberMgr.reset(); 1195 } 1196 1197 /// Remove the handler for the given function. 1198 /// A handler must currently be registered for this function. 1199 template <typename Func> 1200 void removeHandler() { 1201 auto IdItr = LocalFunctionIds.find(Func::getPrototype()); 1202 assert(IdItr != LocalFunctionIds.end() && 1203 "Function does not have a registered handler"); 1204 auto HandlerItr = Handlers.find(IdItr->second); 1205 assert(HandlerItr != Handlers.end() && 1206 "Function does not have a registered handler"); 1207 Handlers.erase(HandlerItr); 1208 } 1209 1210 /// Clear all handlers. 1211 void clearHandlers() { 1212 Handlers.clear(); 1213 } 1214 1215 protected: 1216 1217 FunctionIdT getInvalidFunctionId() const { 1218 return FnIdAllocator.getInvalidId(); 1219 } 1220 1221 /// Add the given handler to the handler map and make it available for 1222 /// autonegotiation and execution. 1223 template <typename Func, typename HandlerT> 1224 void addHandlerImpl(HandlerT Handler) { 1225 1226 static_assert(detail::RPCArgTypeCheck< 1227 CanDeserializeCheck, typename Func::Type, 1228 typename detail::HandlerTraits<HandlerT>::Type>::value, 1229 ""); 1230 1231 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1232 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1233 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); 1234 } 1235 1236 template <typename Func, typename HandlerT> 1237 void addAsyncHandlerImpl(HandlerT Handler) { 1238 1239 static_assert(detail::RPCArgTypeCheck< 1240 CanDeserializeCheck, typename Func::Type, 1241 typename detail::AsyncHandlerTraits< 1242 typename detail::HandlerTraits<HandlerT>::Type 1243 >::Type>::value, 1244 ""); 1245 1246 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1247 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1248 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); 1249 } 1250 1251 Error handleResponse(SequenceNumberT SeqNo) { 1252 using Handler = typename decltype(PendingResponses)::mapped_type; 1253 Handler PRHandler; 1254 1255 { 1256 // Lock the pending responses map and sequence number manager. 1257 std::unique_lock<std::mutex> Lock(ResponsesMutex); 1258 auto I = PendingResponses.find(SeqNo); 1259 1260 if (I != PendingResponses.end()) { 1261 PRHandler = std::move(I->second); 1262 PendingResponses.erase(I); 1263 SequenceNumberMgr.releaseSequenceNumber(SeqNo); 1264 } else { 1265 // Unlock the pending results map to prevent recursive lock. 1266 Lock.unlock(); 1267 abandonPendingResponses(); 1268 return make_error< 1269 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo); 1270 } 1271 } 1272 1273 assert(PRHandler && 1274 "If we didn't find a response handler we should have bailed out"); 1275 1276 if (auto Err = PRHandler->handleResponse(C)) { 1277 abandonPendingResponses(); 1278 return Err; 1279 } 1280 1281 return Error::success(); 1282 } 1283 1284 FunctionIdT handleNegotiate(const std::string &Name) { 1285 auto I = LocalFunctionIds.find(Name); 1286 if (I == LocalFunctionIds.end()) 1287 return getInvalidFunctionId(); 1288 return I->second; 1289 } 1290 1291 // Find the remote FunctionId for the given function. 1292 template <typename Func> 1293 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, 1294 bool NegotiateIfInvalid) { 1295 bool DoNegotiate; 1296 1297 // Check if we already have a function id... 1298 auto I = RemoteFunctionIds.find(Func::getPrototype()); 1299 if (I != RemoteFunctionIds.end()) { 1300 // If it's valid there's nothing left to do. 1301 if (I->second != getInvalidFunctionId()) 1302 return I->second; 1303 DoNegotiate = NegotiateIfInvalid; 1304 } else 1305 DoNegotiate = NegotiateIfNotInMap; 1306 1307 // We don't have a function id for Func yet, but we're allowed to try to 1308 // negotiate one. 1309 if (DoNegotiate) { 1310 auto &Impl = static_cast<ImplT &>(*this); 1311 if (auto RemoteIdOrErr = 1312 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { 1313 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; 1314 if (*RemoteIdOrErr == getInvalidFunctionId()) 1315 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1316 return *RemoteIdOrErr; 1317 } else 1318 return RemoteIdOrErr.takeError(); 1319 } 1320 1321 // No key was available in the map and we weren't allowed to try to 1322 // negotiate one, so return an unknown function error. 1323 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1324 } 1325 1326 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; 1327 1328 // Wrap the given user handler in the necessary argument-deserialization code, 1329 // result-serialization code, and call to the launch policy (if present). 1330 template <typename Func, typename HandlerT> 1331 WrappedHandlerFn wrapHandler(HandlerT Handler) { 1332 return [this, Handler](ChannelT &Channel, 1333 SequenceNumberT SeqNo) mutable -> Error { 1334 // Start by deserializing the arguments. 1335 using ArgsTuple = 1336 typename detail::FunctionArgsTuple< 1337 typename detail::HandlerTraits<HandlerT>::Type>::Type; 1338 auto Args = std::make_shared<ArgsTuple>(); 1339 1340 if (auto Err = 1341 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1342 Channel, *Args)) 1343 return Err; 1344 1345 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1346 // for RPCArgs. Void cast RPCArgs to work around this for now. 1347 // FIXME: Remove this workaround once we can assume a working GCC version. 1348 (void)Args; 1349 1350 // End receieve message, unlocking the channel for reading. 1351 if (auto Err = Channel.endReceiveMessage()) 1352 return Err; 1353 1354 using HTraits = detail::HandlerTraits<HandlerT>; 1355 using FuncReturn = typename Func::ReturnType; 1356 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, 1357 HTraits::unpackAndRun(Handler, *Args)); 1358 }; 1359 } 1360 1361 // Wrap the given user handler in the necessary argument-deserialization code, 1362 // result-serialization code, and call to the launch policy (if present). 1363 template <typename Func, typename HandlerT> 1364 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { 1365 return [this, Handler](ChannelT &Channel, 1366 SequenceNumberT SeqNo) mutable -> Error { 1367 // Start by deserializing the arguments. 1368 using AHTraits = detail::AsyncHandlerTraits< 1369 typename detail::HandlerTraits<HandlerT>::Type>; 1370 using ArgsTuple = 1371 typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; 1372 auto Args = std::make_shared<ArgsTuple>(); 1373 1374 if (auto Err = 1375 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1376 Channel, *Args)) 1377 return Err; 1378 1379 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1380 // for RPCArgs. Void cast RPCArgs to work around this for now. 1381 // FIXME: Remove this workaround once we can assume a working GCC version. 1382 (void)Args; 1383 1384 // End receieve message, unlocking the channel for reading. 1385 if (auto Err = Channel.endReceiveMessage()) 1386 return Err; 1387 1388 using HTraits = detail::HandlerTraits<HandlerT>; 1389 using FuncReturn = typename Func::ReturnType; 1390 auto Responder = 1391 [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { 1392 return detail::respond<FuncReturn>(C, ResponseId, SeqNo, 1393 std::move(RetVal)); 1394 }; 1395 1396 return HTraits::unpackAndRunAsync(Handler, Responder, *Args); 1397 }; 1398 } 1399 1400 ChannelT &C; 1401 1402 bool LazyAutoNegotiation; 1403 1404 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; 1405 1406 FunctionIdT ResponseId; 1407 std::map<std::string, FunctionIdT> LocalFunctionIds; 1408 std::map<const char *, FunctionIdT> RemoteFunctionIds; 1409 1410 std::map<FunctionIdT, WrappedHandlerFn> Handlers; 1411 1412 std::mutex ResponsesMutex; 1413 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; 1414 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> 1415 PendingResponses; 1416 }; 1417 1418 } // end namespace detail 1419 1420 template <typename ChannelT, typename FunctionIdT = uint32_t, 1421 typename SequenceNumberT = uint32_t> 1422 class MultiThreadedRPCEndpoint 1423 : public detail::RPCEndpointBase< 1424 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1425 ChannelT, FunctionIdT, SequenceNumberT> { 1426 private: 1427 using BaseClass = 1428 detail::RPCEndpointBase< 1429 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1430 ChannelT, FunctionIdT, SequenceNumberT>; 1431 1432 public: 1433 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1434 : BaseClass(C, LazyAutoNegotiation) {} 1435 1436 /// Add a handler for the given RPC function. 1437 /// This installs the given handler functor for the given RPC Function, and 1438 /// makes the RPC function available for negotiation/calling from the remote. 1439 template <typename Func, typename HandlerT> 1440 void addHandler(HandlerT Handler) { 1441 return this->template addHandlerImpl<Func>(std::move(Handler)); 1442 } 1443 1444 /// Add a class-method as a handler. 1445 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1446 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1447 addHandler<Func>( 1448 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1449 } 1450 1451 template <typename Func, typename HandlerT> 1452 void addAsyncHandler(HandlerT Handler) { 1453 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1454 } 1455 1456 /// Add a class-method as a handler. 1457 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1458 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1459 addAsyncHandler<Func>( 1460 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1461 } 1462 1463 /// Return type for non-blocking call primitives. 1464 template <typename Func> 1465 using NonBlockingCallResult = typename detail::ResultTraits< 1466 typename Func::ReturnType>::ReturnFutureType; 1467 1468 /// Call Func on Channel C. Does not block, does not call send. Returns a pair 1469 /// of a future result and the sequence number assigned to the result. 1470 /// 1471 /// This utility function is primarily used for single-threaded mode support, 1472 /// where the sequence number can be used to wait for the corresponding 1473 /// result. In multi-threaded mode the appendCallNB method, which does not 1474 /// return the sequence numeber, should be preferred. 1475 template <typename Func, typename... ArgTs> 1476 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) { 1477 using RTraits = detail::ResultTraits<typename Func::ReturnType>; 1478 using ErrorReturn = typename RTraits::ErrorReturnType; 1479 using ErrorReturnPromise = typename RTraits::ReturnPromiseType; 1480 1481 // FIXME: Stack allocate and move this into the handler once LLVM builds 1482 // with C++14. 1483 auto Promise = std::make_shared<ErrorReturnPromise>(); 1484 auto FutureResult = Promise->get_future(); 1485 1486 if (auto Err = this->template appendCallAsync<Func>( 1487 [Promise](ErrorReturn RetOrErr) { 1488 Promise->set_value(std::move(RetOrErr)); 1489 return Error::success(); 1490 }, 1491 Args...)) { 1492 RTraits::consumeAbandoned(FutureResult.get()); 1493 return std::move(Err); 1494 } 1495 return std::move(FutureResult); 1496 } 1497 1498 /// The same as appendCallNBWithSeq, except that it calls C.send() to 1499 /// flush the channel after serializing the call. 1500 template <typename Func, typename... ArgTs> 1501 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) { 1502 auto Result = appendCallNB<Func>(Args...); 1503 if (!Result) 1504 return Result; 1505 if (auto Err = this->C.send()) { 1506 this->abandonPendingResponses(); 1507 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1508 std::move(Result->get())); 1509 return std::move(Err); 1510 } 1511 return Result; 1512 } 1513 1514 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error 1515 /// for void functions or an Expected<T> for functions returning a T. 1516 /// 1517 /// This function is for use in threaded code where another thread is 1518 /// handling responses and incoming calls. 1519 template <typename Func, typename... ArgTs, 1520 typename AltRetT = typename Func::ReturnType> 1521 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1522 callB(const ArgTs &... Args) { 1523 if (auto FutureResOrErr = callNB<Func>(Args...)) 1524 return FutureResOrErr->get(); 1525 else 1526 return FutureResOrErr.takeError(); 1527 } 1528 1529 /// Handle incoming RPC calls. 1530 Error handlerLoop() { 1531 while (true) 1532 if (auto Err = this->handleOne()) 1533 return Err; 1534 return Error::success(); 1535 } 1536 }; 1537 1538 template <typename ChannelT, typename FunctionIdT = uint32_t, 1539 typename SequenceNumberT = uint32_t> 1540 class SingleThreadedRPCEndpoint 1541 : public detail::RPCEndpointBase< 1542 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1543 ChannelT, FunctionIdT, SequenceNumberT> { 1544 private: 1545 using BaseClass = 1546 detail::RPCEndpointBase< 1547 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1548 ChannelT, FunctionIdT, SequenceNumberT>; 1549 1550 public: 1551 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1552 : BaseClass(C, LazyAutoNegotiation) {} 1553 1554 template <typename Func, typename HandlerT> 1555 void addHandler(HandlerT Handler) { 1556 return this->template addHandlerImpl<Func>(std::move(Handler)); 1557 } 1558 1559 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1560 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1561 addHandler<Func>( 1562 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1563 } 1564 1565 template <typename Func, typename HandlerT> 1566 void addAsyncHandler(HandlerT Handler) { 1567 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1568 } 1569 1570 /// Add a class-method as a handler. 1571 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1572 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1573 addAsyncHandler<Func>( 1574 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1575 } 1576 1577 template <typename Func, typename... ArgTs, 1578 typename AltRetT = typename Func::ReturnType> 1579 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1580 callB(const ArgTs &... Args) { 1581 bool ReceivedResponse = false; 1582 using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType; 1583 auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); 1584 1585 // We have to 'Check' result (which we know is in a success state at this 1586 // point) so that it can be overwritten in the async handler. 1587 (void)!!Result; 1588 1589 if (auto Err = this->template appendCallAsync<Func>( 1590 [&](ResultType R) { 1591 Result = std::move(R); 1592 ReceivedResponse = true; 1593 return Error::success(); 1594 }, 1595 Args...)) { 1596 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1597 std::move(Result)); 1598 return std::move(Err); 1599 } 1600 1601 while (!ReceivedResponse) { 1602 if (auto Err = this->handleOne()) { 1603 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1604 std::move(Result)); 1605 return std::move(Err); 1606 } 1607 } 1608 1609 return Result; 1610 } 1611 }; 1612 1613 /// Asynchronous dispatch for a function on an RPC endpoint. 1614 template <typename RPCClass, typename Func> 1615 class RPCAsyncDispatch { 1616 public: 1617 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} 1618 1619 template <typename HandlerT, typename... ArgTs> 1620 Error operator()(HandlerT Handler, const ArgTs &... Args) const { 1621 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); 1622 } 1623 1624 private: 1625 RPCClass &Endpoint; 1626 }; 1627 1628 /// Construct an asynchronous dispatcher from an RPC endpoint and a Func. 1629 template <typename Func, typename RPCEndpointT> 1630 RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { 1631 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); 1632 } 1633 1634 /// \brief Allows a set of asynchrounous calls to be dispatched, and then 1635 /// waited on as a group. 1636 class ParallelCallGroup { 1637 public: 1638 1639 ParallelCallGroup() = default; 1640 ParallelCallGroup(const ParallelCallGroup &) = delete; 1641 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; 1642 1643 /// \brief Make as asynchronous call. 1644 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> 1645 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, 1646 const ArgTs &... Args) { 1647 // Increment the count of outstanding calls. This has to happen before 1648 // we invoke the call, as the handler may (depending on scheduling) 1649 // be run immediately on another thread, and we don't want the decrement 1650 // in the wrapped handler below to run before the increment. 1651 { 1652 std::unique_lock<std::mutex> Lock(M); 1653 ++NumOutstandingCalls; 1654 } 1655 1656 // Wrap the user handler in a lambda that will decrement the 1657 // outstanding calls count, then poke the condition variable. 1658 using ArgType = typename detail::ResponseHandlerArg< 1659 typename detail::HandlerTraits<HandlerT>::Type>::ArgType; 1660 // FIXME: Move handler into wrapped handler once we have C++14. 1661 auto WrappedHandler = [this, Handler](ArgType Arg) { 1662 auto Err = Handler(std::move(Arg)); 1663 std::unique_lock<std::mutex> Lock(M); 1664 --NumOutstandingCalls; 1665 CV.notify_all(); 1666 return Err; 1667 }; 1668 1669 return AsyncDispatch(std::move(WrappedHandler), Args...); 1670 } 1671 1672 /// \brief Blocks until all calls have been completed and their return value 1673 /// handlers run. 1674 void wait() { 1675 std::unique_lock<std::mutex> Lock(M); 1676 while (NumOutstandingCalls > 0) 1677 CV.wait(Lock); 1678 } 1679 1680 private: 1681 std::mutex M; 1682 std::condition_variable CV; 1683 uint32_t NumOutstandingCalls = 0; 1684 }; 1685 1686 /// @brief Convenience class for grouping RPC Functions into APIs that can be 1687 /// negotiated as a block. 1688 /// 1689 template <typename... Funcs> 1690 class APICalls { 1691 public: 1692 1693 /// @brief Test whether this API contains Function F. 1694 template <typename F> 1695 class Contains { 1696 public: 1697 static const bool value = false; 1698 }; 1699 1700 /// @brief Negotiate all functions in this API. 1701 template <typename RPCEndpoint> 1702 static Error negotiate(RPCEndpoint &R) { 1703 return Error::success(); 1704 } 1705 }; 1706 1707 template <typename Func, typename... Funcs> 1708 class APICalls<Func, Funcs...> { 1709 public: 1710 1711 template <typename F> 1712 class Contains { 1713 public: 1714 static const bool value = std::is_same<F, Func>::value | 1715 APICalls<Funcs...>::template Contains<F>::value; 1716 }; 1717 1718 template <typename RPCEndpoint> 1719 static Error negotiate(RPCEndpoint &R) { 1720 if (auto Err = R.template negotiateFunction<Func>()) 1721 return Err; 1722 return APICalls<Funcs...>::negotiate(R); 1723 } 1724 1725 }; 1726 1727 template <typename... InnerFuncs, typename... Funcs> 1728 class APICalls<APICalls<InnerFuncs...>, Funcs...> { 1729 public: 1730 1731 template <typename F> 1732 class Contains { 1733 public: 1734 static const bool value = 1735 APICalls<InnerFuncs...>::template Contains<F>::value | 1736 APICalls<Funcs...>::template Contains<F>::value; 1737 }; 1738 1739 template <typename RPCEndpoint> 1740 static Error negotiate(RPCEndpoint &R) { 1741 if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) 1742 return Err; 1743 return APICalls<Funcs...>::negotiate(R); 1744 } 1745 1746 }; 1747 1748 } // end namespace rpc 1749 } // end namespace orc 1750 } // end namespace llvm 1751 1752 #endif 1753