Home | History | Annotate | Download | only in libpdx_uds
#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <thread>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>
#include <pdx/service_dispatcher.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>
     20 
     21 using testing::Return;
     22 using testing::_;
     23 
     24 using android::pdx::ClientBase;
     25 using android::pdx::LocalChannelHandle;
     26 using android::pdx::LocalHandle;
     27 using android::pdx::Message;
     28 using android::pdx::ServiceBase;
     29 using android::pdx::ServiceDispatcher;
     30 using android::pdx::Status;
     31 using android::pdx::rpc::DispatchRemoteMethod;
     32 using android::pdx::uds::ClientChannel;
     33 using android::pdx::uds::ClientChannelFactory;
     34 using android::pdx::uds::Endpoint;
     35 
     36 namespace {
     37 
     38 struct TestProtocol {
     39   using DataType = int8_t;
     40   enum {
     41     kOpSum = 0,
     42   };
     43   PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
     44 };
     45 
     46 class TestService : public ServiceBase<TestService> {
     47  public:
     48   TestService(std::unique_ptr<Endpoint> endpoint)
     49       : ServiceBase{"TestService", std::move(endpoint)} {}
     50 
     51   Status<void> HandleMessage(Message& message) override {
     52     switch (message.GetOp()) {
     53       case TestProtocol::kOpSum:
     54         DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
     55                                                 message);
     56         return {};
     57 
     58       default:
     59         return Service::HandleMessage(message);
     60     }
     61   }
     62 
     63   int64_t OnSum(Message& /*message*/,
     64                 const std::vector<TestProtocol::DataType>& data) {
     65     return std::accumulate(data.begin(), data.end(), int64_t{0});
     66   }
     67 };
     68 
     69 class TestClient : public ClientBase<TestClient> {
     70  public:
     71   using ClientBase::ClientBase;
     72 
     73   int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
     74     auto status = InvokeRemoteMethod<TestProtocol::Sum>(data);
     75     return status ? status.get() : -1;
     76   }
     77 };
     78 
     79 class TestServiceRunner {
     80  public:
     81   TestServiceRunner(LocalHandle channel_socket) {
     82     auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
     83     endpoint->RegisterNewChannelForTests(std::move(channel_socket));
     84     service_ = TestService::Create(std::move(endpoint));
     85     dispatcher_ = ServiceDispatcher::Create();
     86     dispatcher_->AddService(service_);
     87     dispatch_thread_ = std::thread(
     88         std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
     89   }
     90 
     91   ~TestServiceRunner() {
     92     dispatcher_->SetCanceled(true);
     93     dispatch_thread_.join();
     94     dispatcher_->RemoveService(service_);
     95   }
     96 
     97  private:
     98   std::shared_ptr<TestService> service_;
     99   std::unique_ptr<ServiceDispatcher> dispatcher_;
    100   std::thread dispatch_thread_;
    101 };
    102 
    103 class ClientChannelTest : public testing::Test {
    104  public:
    105   void SetUp() override {
    106     int channel_sockets[2] = {};
    107     ASSERT_EQ(
    108         0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
    109     LocalHandle service_channel{channel_sockets[0]};
    110     LocalHandle client_channel{channel_sockets[1]};
    111 
    112     service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
    113     auto factory = ClientChannelFactory::Create(std::move(client_channel));
    114     auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
    115     ASSERT_TRUE(status);
    116     client_ = TestClient::Create(status.take());
    117   }
    118 
    119   void TearDown() override {
    120     service_runner_.reset();
    121     client_.reset();
    122   }
    123 
    124  protected:
    125   std::unique_ptr<TestServiceRunner> service_runner_;
    126   std::shared_ptr<TestClient> client_;
    127 };
    128 
    129 TEST_F(ClientChannelTest, MultithreadedClient) {
    130   constexpr int kNumTestThreads = 8;
    131   constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.
    132 
    133   std::random_device rd;
    134   std::mt19937 gen{rd()};
    135   std::uniform_int_distribution<TestProtocol::DataType> dist{
    136       std::numeric_limits<TestProtocol::DataType>::min(),
    137       std::numeric_limits<TestProtocol::DataType>::max()};
    138 
    139   auto worker = [](std::shared_ptr<TestClient> client,
    140                    std::vector<TestProtocol::DataType> data) {
    141     constexpr int kMaxIterations = 500;
    142     int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0});
    143     for (int i = 0; i < kMaxIterations; i++) {
    144       ASSERT_EQ(expected, client->Sum(data));
    145     }
    146   };
    147 
    148   // Start client threads.
    149   std::vector<TestProtocol::DataType> data;
    150   data.resize(kDataSize);
    151   std::vector<std::thread> threads;
    152   for (int i = 0; i < kNumTestThreads; i++) {
    153     std::generate(data.begin(), data.end(),
    154                   [&dist, &gen]() { return dist(gen); });
    155     threads.emplace_back(worker, client_, data);
    156   }
    157 
    158   // Wait for threads to finish.
    159   for (auto& thread : threads)
    160     thread.join();
    161 }
    162 
    163 }  // namespace
    164