Home | History | Annotate | Download | only in Orc
      1 //===- llvm/ExecutionEngine/Orc/RawByteChannel.h ----------------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
     11 #define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
     12 
#include "OrcError.h"
#include "RPCSerialization.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Error.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <mutex>
#include <string>
#include <system_error>
#include <tuple>
#include <type_traits>
#include <vector>
     27 
     28 namespace llvm {
     29 namespace orc {
     30 namespace rpc {
     31 
     32 /// Interface for byte-streams to be used with RPC.
     33 class RawByteChannel {
     34 public:
     35   virtual ~RawByteChannel() {}
     36 
     37   /// Read Size bytes from the stream into *Dst.
     38   virtual Error readBytes(char *Dst, unsigned Size) = 0;
     39 
     40   /// Read size bytes from *Src and append them to the stream.
     41   virtual Error appendBytes(const char *Src, unsigned Size) = 0;
     42 
     43   /// Flush the stream if possible.
     44   virtual Error send() = 0;
     45 
     46   /// Notify the channel that we're starting a message send.
     47   /// Locks the channel for writing.
     48   template <typename FunctionIdT, typename SequenceIdT>
     49   Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
     50     writeLock.lock();
     51     if (auto Err = serializeSeq(*this, FnId, SeqNo)) {
     52       writeLock.unlock();
     53       return Err;
     54     }
     55     return Error::success();
     56   }
     57 
     58   /// Notify the channel that we're ending a message send.
     59   /// Unlocks the channel for writing.
     60   Error endSendMessage() {
     61     writeLock.unlock();
     62     return Error::success();
     63   }
     64 
     65   /// Notify the channel that we're starting a message receive.
     66   /// Locks the channel for reading.
     67   template <typename FunctionIdT, typename SequenceNumberT>
     68   Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
     69     readLock.lock();
     70     if (auto Err = deserializeSeq(*this, FnId, SeqNo)) {
     71       readLock.unlock();
     72       return Err;
     73     }
     74     return Error::success();
     75   }
     76 
     77   /// Notify the channel that we're ending a message receive.
     78   /// Unlocks the channel for reading.
     79   Error endReceiveMessage() {
     80     readLock.unlock();
     81     return Error::success();
     82   }
     83 
     84   /// Get the lock for stream reading.
     85   std::mutex &getReadLock() { return readLock; }
     86 
     87   /// Get the lock for stream writing.
     88   std::mutex &getWriteLock() { return writeLock; }
     89 
     90 private:
     91   std::mutex readLock, writeLock;
     92 };
     93 
     94 template <typename ChannelT, typename T>
     95 class SerializationTraits<
     96     ChannelT, T, T,
     97     typename std::enable_if<
     98         std::is_base_of<RawByteChannel, ChannelT>::value &&
     99         (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
    100          std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
    101          std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
    102          std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
    103          std::is_same<T, char>::value)>::type> {
    104 public:
    105   static Error serialize(ChannelT &C, T V) {
    106     support::endian::byte_swap<T, support::big>(V);
    107     return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
    108   };
    109 
    110   static Error deserialize(ChannelT &C, T &V) {
    111     if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
    112       return Err;
    113     support::endian::byte_swap<T, support::big>(V);
    114     return Error::success();
    115   };
    116 };
    117 
    118 template <typename ChannelT>
    119 class SerializationTraits<ChannelT, bool, bool,
    120                           typename std::enable_if<std::is_base_of<
    121                               RawByteChannel, ChannelT>::value>::type> {
    122 public:
    123   static Error serialize(ChannelT &C, bool V) {
    124     uint8_t Tmp = V ? 1 : 0;
    125     if (auto Err =
    126           C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1))
    127       return Err;
    128     return Error::success();
    129   }
    130 
    131   static Error deserialize(ChannelT &C, bool &V) {
    132     uint8_t Tmp = 0;
    133     if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1))
    134       return Err;
    135     V = Tmp != 0;
    136     return Error::success();
    137   }
    138 };
    139 
    140 template <typename ChannelT>
    141 class SerializationTraits<ChannelT, std::string, StringRef,
    142                           typename std::enable_if<std::is_base_of<
    143                               RawByteChannel, ChannelT>::value>::type> {
    144 public:
    145   /// RPC channel serialization for std::strings.
    146   static Error serialize(RawByteChannel &C, StringRef S) {
    147     if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size())))
    148       return Err;
    149     return C.appendBytes((const char *)S.data(), S.size());
    150   }
    151 };
    152 
    153 template <typename ChannelT, typename T>
    154 class SerializationTraits<ChannelT, std::string, T,
    155                           typename std::enable_if<
    156                             std::is_base_of<RawByteChannel, ChannelT>::value &&
    157                             (std::is_same<T, const char*>::value ||
    158                              std::is_same<T, char*>::value)>::type> {
    159 public:
    160   static Error serialize(RawByteChannel &C, const char *S) {
    161     return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
    162                                                                             S);
    163   }
    164 };
    165 
    166 template <typename ChannelT>
    167 class SerializationTraits<ChannelT, std::string, std::string,
    168                           typename std::enable_if<std::is_base_of<
    169                               RawByteChannel, ChannelT>::value>::type> {
    170 public:
    171   /// RPC channel serialization for std::strings.
    172   static Error serialize(RawByteChannel &C, const std::string &S) {
    173     return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
    174                                                                             S);
    175   }
    176 
    177   /// RPC channel deserialization for std::strings.
    178   static Error deserialize(RawByteChannel &C, std::string &S) {
    179     uint64_t Count = 0;
    180     if (auto Err = deserializeSeq(C, Count))
    181       return Err;
    182     S.resize(Count);
    183     return C.readBytes(&S[0], Count);
    184   }
    185 };
    186 
    187 } // end namespace rpc
    188 } // end namespace orc
    189 } // end namespace llvm
    190 
    191 #endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
    192