1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "src/tracing/test/mock_producer.h" 18 19 #include "perfetto/tracing/core/data_source_config.h" 20 #include "perfetto/tracing/core/data_source_descriptor.h" 21 #include "perfetto/tracing/core/trace_writer.h" 22 #include "src/base/test/test_task_runner.h" 23 24 using ::testing::_; 25 using ::testing::Eq; 26 using ::testing::Invoke; 27 using ::testing::InvokeWithoutArgs; 28 using ::testing::Property; 29 30 namespace perfetto { 31 32 MockProducer::MockProducer(base::TestTaskRunner* task_runner) 33 : task_runner_(task_runner) {} 34 35 MockProducer::~MockProducer() { 36 if (!service_endpoint_) 37 return; 38 static int i = 0; 39 auto checkpoint_name = "on_producer_disconnect_" + std::to_string(i++); 40 auto on_disconnect = task_runner_->CreateCheckpoint(checkpoint_name); 41 EXPECT_CALL(*this, OnDisconnect()).WillOnce(Invoke(on_disconnect)); 42 service_endpoint_.reset(); 43 task_runner_->RunUntilCheckpoint(checkpoint_name); 44 } 45 46 void MockProducer::Connect(TracingService* svc, 47 const std::string& producer_name, 48 uid_t uid, 49 size_t shared_memory_size_hint_bytes) { 50 producer_name_ = producer_name; 51 service_endpoint_ = 52 svc->ConnectProducer(this, uid, producer_name, 53 shared_memory_size_hint_bytes, /*in_process=*/true); 54 auto checkpoint_name = "on_producer_connect_" + producer_name; 55 auto on_connect = task_runner_->CreateCheckpoint(checkpoint_name); 56 EXPECT_CALL(*this, OnConnect()).WillOnce(Invoke(on_connect)); 57 task_runner_->RunUntilCheckpoint(checkpoint_name); 58 } 59 60 void MockProducer::RegisterDataSource(const std::string& name, 61 bool ack_stop, 62 bool ack_start, 63 bool handle_incremental_state_clear) { 64 DataSourceDescriptor ds_desc; 65 ds_desc.set_name(name); 66 ds_desc.set_will_notify_on_stop(ack_stop); 67 ds_desc.set_will_notify_on_start(ack_start); 68 ds_desc.set_handles_incremental_state_clear(handle_incremental_state_clear); 69 service_endpoint_->RegisterDataSource(ds_desc); 70 } 71 72 void MockProducer::UnregisterDataSource(const std::string& name) { 73 service_endpoint_->UnregisterDataSource(name); 74 } 75 76 void MockProducer::RegisterTraceWriter(uint32_t writer_id, 77 uint32_t target_buffer) { 78 service_endpoint_->RegisterTraceWriter(writer_id, target_buffer); 79 } 80 81 void MockProducer::UnregisterTraceWriter(uint32_t writer_id) { 82 service_endpoint_->UnregisterTraceWriter(writer_id); 83 } 84 85 void MockProducer::WaitForTracingSetup() { 86 static int i = 0; 87 auto checkpoint_name = 88 "on_shmem_initialized_" + producer_name_ + "_" + std::to_string(i++); 89 auto on_tracing_enabled = task_runner_->CreateCheckpoint(checkpoint_name); 90 EXPECT_CALL(*this, OnTracingSetup()).WillOnce(Invoke(on_tracing_enabled)); 91 task_runner_->RunUntilCheckpoint(checkpoint_name); 92 } 93 94 void MockProducer::WaitForDataSourceSetup(const std::string& name) { 95 static int i = 0; 96 auto checkpoint_name = "on_ds_setup_" + name + "_" + std::to_string(i++); 97 auto on_ds_start = task_runner_->CreateCheckpoint(checkpoint_name); 98 EXPECT_CALL(*this, 99 SetupDataSource(_, Property(&DataSourceConfig::name, Eq(name)))) 100 .WillOnce(Invoke([on_ds_start, this](DataSourceInstanceID ds_id, 101 const DataSourceConfig& cfg) { 102 EXPECT_FALSE(data_source_instances_.count(cfg.name())); 103 auto target_buffer = static_cast<BufferID>(cfg.target_buffer()); 104 auto session_id = 105 static_cast<TracingSessionID>(cfg.tracing_session_id()); 106 data_source_instances_.emplace( 107 cfg.name(), EnabledDataSource{ds_id, target_buffer, session_id}); 108 on_ds_start(); 109 })); 110 task_runner_->RunUntilCheckpoint(checkpoint_name); 111 } 112 113 void MockProducer::WaitForDataSourceStart(const std::string& name) { 114 static int i = 0; 115 auto checkpoint_name = "on_ds_start_" + name + "_" + std::to_string(i++); 116 auto on_ds_start = task_runner_->CreateCheckpoint(checkpoint_name); 117 EXPECT_CALL(*this, 118 StartDataSource(_, Property(&DataSourceConfig::name, Eq(name)))) 119 .WillOnce(Invoke([on_ds_start, this](DataSourceInstanceID ds_id, 120 const DataSourceConfig& cfg) { 121 // The data source might have been seen already through 122 // WaitForDataSourceSetup(). 123 if (data_source_instances_.count(cfg.name()) == 0) { 124 auto target_buffer = static_cast<BufferID>(cfg.target_buffer()); 125 auto session_id = 126 static_cast<TracingSessionID>(cfg.tracing_session_id()); 127 data_source_instances_.emplace( 128 cfg.name(), EnabledDataSource{ds_id, target_buffer, session_id}); 129 } 130 on_ds_start(); 131 })); 132 task_runner_->RunUntilCheckpoint(checkpoint_name); 133 } 134 135 void MockProducer::WaitForDataSourceStop(const std::string& name) { 136 static int i = 0; 137 auto checkpoint_name = "on_ds_stop_" + name + "_" + std::to_string(i++); 138 auto on_ds_stop = task_runner_->CreateCheckpoint(checkpoint_name); 139 ASSERT_EQ(1u, data_source_instances_.count(name)); 140 DataSourceInstanceID ds_id = data_source_instances_[name].id; 141 EXPECT_CALL(*this, StopDataSource(ds_id)) 142 .WillOnce(InvokeWithoutArgs(on_ds_stop)); 143 task_runner_->RunUntilCheckpoint(checkpoint_name); 144 data_source_instances_.erase(name); 145 } 146 147 std::unique_ptr<TraceWriter> MockProducer::CreateTraceWriter( 148 const std::string& data_source_name) { 149 PERFETTO_DCHECK(data_source_instances_.count(data_source_name)); 150 BufferID buf_id = data_source_instances_[data_source_name].target_buffer; 151 return service_endpoint_->CreateTraceWriter(buf_id); 152 } 153 154 void MockProducer::WaitForFlush(TraceWriter* writer_to_flush, bool reply) { 155 std::vector<TraceWriter*> writers; 156 if (writer_to_flush) 157 writers.push_back(writer_to_flush); 158 WaitForFlush(writers, reply); 159 } 160 161 void MockProducer::WaitForFlush(std::vector<TraceWriter*> writers_to_flush, 162 bool reply) { 163 auto& expected_call = EXPECT_CALL(*this, Flush(_, _, _)); 164 expected_call.WillOnce(Invoke( 165 [this, writers_to_flush, reply](FlushRequestID flush_req_id, 166 const DataSourceInstanceID*, size_t) { 167 for (auto* writer : writers_to_flush) 168 writer->Flush(); 169 if (reply) 170 service_endpoint_->NotifyFlushComplete(flush_req_id); 171 })); 172 } 173 174 DataSourceInstanceID MockProducer::GetDataSourceInstanceId( 175 const std::string& name) { 176 auto it = data_source_instances_.find(name); 177 return it == data_source_instances_.end() ? 0 : it->second.id; 178 } 179 180 const MockProducer::EnabledDataSource* MockProducer::GetDataSourceInstance( 181 const std::string& name) { 182 auto it = data_source_instances_.find(name); 183 return it == data_source_instances_.end() ? nullptr : &it->second; 184 } 185 186 } // namespace perfetto 187