1 #include <uds/client_channel.h> 2 3 #include <sys/socket.h> 4 5 #include <algorithm> 6 #include <limits> 7 #include <random> 8 #include <thread> 9 10 #include <gmock/gmock.h> 11 #include <gtest/gtest.h> 12 13 #include <pdx/client.h> 14 #include <pdx/rpc/remote_method.h> 15 #include <pdx/service.h> 16 #include <pdx/service_dispatcher.h> 17 18 #include <uds/client_channel_factory.h> 19 #include <uds/service_endpoint.h> 20 21 using testing::Return; 22 using testing::_; 23 24 using android::pdx::ClientBase; 25 using android::pdx::LocalChannelHandle; 26 using android::pdx::LocalHandle; 27 using android::pdx::Message; 28 using android::pdx::ServiceBase; 29 using android::pdx::ServiceDispatcher; 30 using android::pdx::Status; 31 using android::pdx::rpc::DispatchRemoteMethod; 32 using android::pdx::uds::ClientChannel; 33 using android::pdx::uds::ClientChannelFactory; 34 using android::pdx::uds::Endpoint; 35 36 namespace { 37 38 struct TestProtocol { 39 using DataType = int8_t; 40 enum { 41 kOpSum = 0, 42 }; 43 PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&)); 44 }; 45 46 class TestService : public ServiceBase<TestService> { 47 public: 48 explicit TestService(std::unique_ptr<Endpoint> endpoint) 49 : ServiceBase{"TestService", std::move(endpoint)} {} 50 51 Status<void> HandleMessage(Message& message) override { 52 switch (message.GetOp()) { 53 case TestProtocol::kOpSum: 54 DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum, 55 message); 56 return {}; 57 58 default: 59 return Service::HandleMessage(message); 60 } 61 } 62 63 int64_t OnSum(Message& /*message*/, 64 const std::vector<TestProtocol::DataType>& data) { 65 return std::accumulate(data.begin(), data.end(), int64_t{0}); 66 } 67 }; 68 69 class TestClient : public ClientBase<TestClient> { 70 public: 71 using ClientBase::ClientBase; 72 73 int64_t Sum(const std::vector<TestProtocol::DataType>& data) { 74 auto status = InvokeRemoteMethod<TestProtocol::Sum>(data); 75 return status ? status.get() : -1; 76 } 77 }; 78 79 class TestServiceRunner { 80 public: 81 explicit TestServiceRunner(LocalHandle channel_socket) { 82 auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{}); 83 endpoint->RegisterNewChannelForTests(std::move(channel_socket)); 84 service_ = TestService::Create(std::move(endpoint)); 85 dispatcher_ = ServiceDispatcher::Create(); 86 dispatcher_->AddService(service_); 87 dispatch_thread_ = std::thread( 88 std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get())); 89 } 90 91 ~TestServiceRunner() { 92 dispatcher_->SetCanceled(true); 93 dispatch_thread_.join(); 94 dispatcher_->RemoveService(service_); 95 } 96 97 private: 98 std::shared_ptr<TestService> service_; 99 std::unique_ptr<ServiceDispatcher> dispatcher_; 100 std::thread dispatch_thread_; 101 }; 102 103 class ClientChannelTest : public testing::Test { 104 public: 105 void SetUp() override { 106 int channel_sockets[2] = {}; 107 ASSERT_EQ( 108 0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets)); 109 LocalHandle service_channel{channel_sockets[0]}; 110 LocalHandle client_channel{channel_sockets[1]}; 111 112 service_runner_.reset(new TestServiceRunner{std::move(service_channel)}); 113 auto factory = ClientChannelFactory::Create(std::move(client_channel)); 114 auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout); 115 ASSERT_TRUE(status); 116 client_ = TestClient::Create(status.take()); 117 } 118 119 void TearDown() override { 120 service_runner_.reset(); 121 client_.reset(); 122 } 123 124 protected: 125 std::unique_ptr<TestServiceRunner> service_runner_; 126 std::shared_ptr<TestClient> client_; 127 }; 128 129 TEST_F(ClientChannelTest, MultithreadedClient) { 130 constexpr int kNumTestThreads = 8; 131 constexpr size_t kDataSize = 1000; // Try to keep RPC buffer size below 4K. 132 133 std::random_device rd; 134 std::mt19937 gen{rd()}; 135 std::uniform_int_distribution<TestProtocol::DataType> dist{ 136 std::numeric_limits<TestProtocol::DataType>::min(), 137 std::numeric_limits<TestProtocol::DataType>::max()}; 138 139 auto worker = [](std::shared_ptr<TestClient> client, 140 std::vector<TestProtocol::DataType> data) { 141 constexpr int kMaxIterations = 500; 142 int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0}); 143 for (int i = 0; i < kMaxIterations; i++) { 144 ASSERT_EQ(expected, client->Sum(data)); 145 } 146 }; 147 148 // Start client threads. 149 std::vector<TestProtocol::DataType> data; 150 data.resize(kDataSize); 151 std::vector<std::thread> threads; 152 for (int i = 0; i < kNumTestThreads; i++) { 153 std::generate(data.begin(), data.end(), 154 [&dist, &gen]() { return dist(gen); }); 155 threads.emplace_back(worker, client_, data); 156 } 157 158 // Wait for threads to finish. 159 for (auto& thread : threads) 160 thread.join(); 161 } 162 163 } // namespace 164