1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <functional> 18 #include <map> 19 #include <string> 20 #include <vector> 21 22 #include <google/protobuf/descriptor.h> 23 #include <google/protobuf/compiler/plugin.h> 24 #include <google/protobuf/compiler/code_generator.h> 25 #include <google/protobuf/io/printer.h> 26 #include <google/protobuf/io/zero_copy_stream.h> 27 #include <google/protobuf/stubs/strutil.h> 28 29 #include "nugget/protobuf/options.pb.h" 30 31 using ::google::protobuf::FileDescriptor; 32 using ::google::protobuf::JoinStrings; 33 using ::google::protobuf::MethodDescriptor; 34 using ::google::protobuf::ServiceDescriptor; 35 using ::google::protobuf::Split; 36 using ::google::protobuf::SplitStringUsing; 37 using ::google::protobuf::StripSuffixString; 38 using ::google::protobuf::compiler::CodeGenerator; 39 using ::google::protobuf::compiler::OutputDirectory; 40 using ::google::protobuf::io::Printer; 41 using ::google::protobuf::io::ZeroCopyOutputStream; 42 43 using ::nugget::protobuf::app_id; 44 using ::nugget::protobuf::request_buffer_size; 45 using ::nugget::protobuf::response_buffer_size; 46 47 namespace { 48 49 std::string validateServiceOptions(const ServiceDescriptor& service) { 50 if (!service.options().HasExtension(app_id)) { 51 return "nugget.protobuf.app_id is not defined for service " + service.name(); 52 } 53 if (!service.options().HasExtension(request_buffer_size)) { 54 return "nugget.protobuf.request_buffer_size is not defined for service " + service.name(); 55 } 56 if (!service.options().HasExtension(response_buffer_size)) { 57 return "nugget.protobuf.response_buffer_size is not defined for service " + service.name(); 58 } 59 return ""; 60 } 61 62 template <typename Descriptor> 63 std::vector<std::string> Packages(const Descriptor& descriptor) { 64 std::vector<std::string> namespaces; 65 SplitStringUsing(descriptor.full_name(), ".", &namespaces); 66 namespaces.pop_back(); // just take the package 67 return namespaces; 68 } 69 70 template <typename Descriptor> 71 std::string FullyQualifiedIdentifier(const Descriptor& descriptor) { 72 const auto namespaces = Packages(descriptor); 73 if (namespaces.empty()) { 74 return "::" + descriptor.name(); 75 } else { 76 std::string namespace_path; 77 JoinStrings(namespaces, "::", &namespace_path); 78 return "::" + namespace_path + "::" + descriptor.name(); 79 } 80 } 81 82 template <typename Descriptor> 83 std::string FullyQualifiedHeader(const Descriptor& descriptor) { 84 const auto packages = Packages(descriptor); 85 const auto file = Split(descriptor.file()->name(), "/").back(); 86 const auto header = StripSuffixString(file, ".proto") + ".pb.h"; 87 if (packages.empty()) { 88 return header; 89 } else { 90 std::string package_path; 91 JoinStrings(packages, "/", &package_path); 92 return package_path + "/" + header; 93 } 94 } 95 96 template <typename Descriptor> 97 void OpenNamespaces(Printer& printer, const Descriptor& descriptor) { 98 const auto namespaces = Packages(descriptor); 99 for (const auto& ns : namespaces) { 100 std::map<std::string, std::string> namespaceVars; 101 namespaceVars["namespace"] = ns; 102 printer.Print(namespaceVars, R"( 103 namespace $namespace$ {)"); 104 } 105 } 106 107 template <typename Descriptor> 108 void CloseNamespaces(Printer& printer, const Descriptor& descriptor) { 109 const auto namespaces = Packages(descriptor); 110 for (auto it = namespaces.crbegin(); it != namespaces.crend(); ++it) { 111 std::map<std::string, std::string> namespaceVars; 112 namespaceVars["namespace"] = *it; 113 printer.Print(namespaceVars, R"( 114 } // namespace $namespace$)"); 115 } 116 } 117 118 void ForEachMethod(const ServiceDescriptor& service, 119 std::function<void(std::map<std::string, std::string>)> handler) { 120 for (int i = 0; i < service.method_count(); ++i) { 121 const MethodDescriptor& method = *service.method(i); 122 std::map<std::string, std::string> vars; 123 vars["method_id"] = std::to_string(i); 124 vars["method_name"] = method.name(); 125 vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type()); 126 vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type()); 127 handler(vars); 128 } 129 } 130 131 void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) { 132 std::map<std::string, std::string> vars; 133 vars["include_guard"] = "PROTOC_GENERATED_MOCK_" + service.name() + "_CLIENT_H"; 134 vars["service_header"] = service.name() + ".client.h"; 135 vars["mock_class"] = "Mock" + service.name(); 136 vars["class"] = service.name(); 137 138 printer.Print(vars, R"( 139 #ifndef $include_guard$ 140 #define $include_guard$ 141 142 #include <gmock/gmock.h> 143 144 #include <$service_header$>)"); 145 146 OpenNamespaces(printer, service); 147 148 printer.Print(vars, R"( 149 struct $mock_class$ : public I$class$ {)"); 150 151 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) { 152 printer.Print(methodVars, R"( 153 MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)"); 154 }); 155 156 printer.Print(vars, R"( 157 };)"); 158 159 CloseNamespaces(printer, service); 160 161 printer.Print(vars, R"( 162 #endif)"); 163 } 164 165 void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) { 166 std::map<std::string, std::string> vars; 167 vars["include_guard"] = "PROTOC_GENERATED_" + service.name() + "_CLIENT_H"; 168 vars["protobuf_header"] = FullyQualifiedHeader(service); 169 vars["class"] = service.name(); 170 vars["iface_class"] = "I" + service.name(); 171 vars["app_id"] = "APP_ID_" + service.options().GetExtension(app_id); 172 173 printer.Print(vars, R"( 174 #ifndef $include_guard$ 175 #define $include_guard$ 176 177 #include <application.h> 178 #include <nos/AppClient.h> 179 #include <nos/NuggetClientInterface.h> 180 181 #include "$protobuf_header$")"); 182 183 OpenNamespaces(printer, service); 184 185 // Pure virtual interface to make testing easier 186 printer.Print(vars, R"( 187 class $iface_class$ { 188 public: 189 virtual ~$iface_class$() = default;)"); 190 191 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) { 192 printer.Print(methodVars, R"( 193 virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)"); 194 }); 195 196 printer.Print(vars, R"( 197 };)"); 198 199 // Implementation of the interface for Nugget 200 printer.Print(vars, R"( 201 class $class$ : public $iface_class$ { 202 ::nos::AppClient _app; 203 public: 204 $class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {} 205 ~$class$() override = default;)"); 206 207 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) { 208 printer.Print(methodVars, R"( 209 uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)"); 210 }); 211 212 printer.Print(vars, R"( 213 };)"); 214 215 CloseNamespaces(printer, service); 216 217 printer.Print(vars, R"( 218 #endif)"); 219 } 220 221 void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) { 222 std::map<std::string, std::string> vars; 223 vars["generated_header"] = service.name() + ".client.h"; 224 vars["class"] = service.name(); 225 226 const uint32_t max_request_size = service.options().GetExtension(request_buffer_size); 227 const uint32_t max_response_size = service.options().GetExtension(response_buffer_size); 228 vars["max_request_size"] = std::to_string(max_request_size); 229 vars["max_response_size"] = std::to_string(max_response_size); 230 231 printer.Print(vars, R"( 232 #include <$generated_header$> 233 234 #include <application.h>)"); 235 236 OpenNamespaces(printer, service); 237 238 // Methods 239 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) { 240 methodVars.insert(vars.begin(), vars.end()); 241 printer.Print(methodVars, R"( 242 uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) { 243 const size_t request_size = request.ByteSize(); 244 if (request_size > $max_request_size$) { 245 return APP_ERROR_TOO_MUCH; 246 } 247 std::vector<uint8_t> buffer(request_size); 248 if (!request.SerializeToArray(buffer.data(), buffer.size())) { 249 return APP_ERROR_RPC; 250 } 251 std::vector<uint8_t> responseBuffer; 252 if (response != nullptr) { 253 responseBuffer.resize($max_response_size$); 254 } 255 const uint32_t appStatus = _app.Call($method_id$, buffer, 256 (response != nullptr) ? &responseBuffer : nullptr); 257 if (appStatus == APP_SUCCESS && response != nullptr) { 258 if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) { 259 return APP_ERROR_RPC; 260 } 261 } 262 return appStatus; 263 })"); 264 }); 265 266 CloseNamespaces(printer, service); 267 } 268 269 // Generator for C++ Nugget service client 270 class CppNuggetServiceClientGenerator : public CodeGenerator { 271 public: 272 CppNuggetServiceClientGenerator() = default; 273 ~CppNuggetServiceClientGenerator() override = default; 274 275 bool Generate(const FileDescriptor* file, 276 const std::string& parameter, 277 OutputDirectory* output_directory, 278 std::string* error) const override { 279 for (int i = 0; i < file->service_count(); ++i) { 280 const auto& service = *file->service(i); 281 282 *error = validateServiceOptions(service); 283 if (!error->empty()) { 284 return false; 285 } 286 287 if (parameter == "mock") { 288 std::unique_ptr<ZeroCopyOutputStream> output{ 289 output_directory->Open("Mock" + service.name() + ".client.h")}; 290 Printer printer(output.get(), '$'); 291 GenerateMockClient(printer, service); 292 } else if (parameter == "header") { 293 std::unique_ptr<ZeroCopyOutputStream> output{ 294 output_directory->Open(service.name() + ".client.h")}; 295 Printer printer(output.get(), '$'); 296 GenerateClientHeader(printer, service); 297 } else if (parameter == "source") { 298 std::unique_ptr<ZeroCopyOutputStream> output{ 299 output_directory->Open(service.name() + ".client.cpp")}; 300 Printer printer(output.get(), '$'); 301 GenerateClientSource(printer, service); 302 } else { 303 *error = "Illegal parameter: must be mock|header|source"; 304 return false; 305 } 306 } 307 308 return true; 309 } 310 311 private: 312 GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppNuggetServiceClientGenerator); 313 }; 314 315 } // namespace 316 317 int main(int argc, char* argv[]) { 318 GOOGLE_PROTOBUF_VERIFY_VERSION; 319 CppNuggetServiceClientGenerator generator; 320 return google::protobuf::compiler::PluginMain(argc, argv, &generator); 321 } 322