Home | History | Annotate | Download | only in generator
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <google/protobuf/descriptor.h>
#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/strutil.h>

#include "nugget/protobuf/options.pb.h"
     30 
     31 using ::google::protobuf::FileDescriptor;
     32 using ::google::protobuf::JoinStrings;
     33 using ::google::protobuf::MethodDescriptor;
     34 using ::google::protobuf::ServiceDescriptor;
     35 using ::google::protobuf::Split;
     36 using ::google::protobuf::SplitStringUsing;
     37 using ::google::protobuf::StripSuffixString;
     38 using ::google::protobuf::compiler::CodeGenerator;
     39 using ::google::protobuf::compiler::OutputDirectory;
     40 using ::google::protobuf::io::Printer;
     41 using ::google::protobuf::io::ZeroCopyOutputStream;
     42 
     43 using ::nugget::protobuf::app_id;
     44 using ::nugget::protobuf::request_buffer_size;
     45 using ::nugget::protobuf::response_buffer_size;
     46 
     47 namespace {
     48 
     49 std::string validateServiceOptions(const ServiceDescriptor& service) {
     50     if (!service.options().HasExtension(app_id)) {
     51         return "nugget.protobuf.app_id is not defined for service " + service.name();
     52     }
     53     if (!service.options().HasExtension(request_buffer_size)) {
     54         return "nugget.protobuf.request_buffer_size is not defined for service " + service.name();
     55     }
     56     if (!service.options().HasExtension(response_buffer_size)) {
     57         return "nugget.protobuf.response_buffer_size is not defined for service " + service.name();
     58     }
     59     return "";
     60 }
     61 
     62 template <typename Descriptor>
     63 std::vector<std::string> Packages(const Descriptor& descriptor) {
     64     std::vector<std::string> namespaces;
     65     SplitStringUsing(descriptor.full_name(), ".", &namespaces);
     66     namespaces.pop_back(); // just take the package
     67     return namespaces;
     68 }
     69 
     70 template <typename Descriptor>
     71 std::string FullyQualifiedIdentifier(const Descriptor& descriptor) {
     72     const auto namespaces = Packages(descriptor);
     73     if (namespaces.empty()) {
     74         return "::" + descriptor.name();
     75     } else {
     76         std::string namespace_path;
     77         JoinStrings(namespaces, "::", &namespace_path);
     78         return "::" + namespace_path + "::" + descriptor.name();
     79     }
     80 }
     81 
     82 template <typename Descriptor>
     83 std::string FullyQualifiedHeader(const Descriptor& descriptor) {
     84     const auto packages = Packages(descriptor);
     85     const auto file = Split(descriptor.file()->name(), "/").back();
     86     const auto header = StripSuffixString(file, ".proto") + ".pb.h";
     87     if (packages.empty()) {
     88         return header;
     89     } else {
     90         std::string package_path;
     91         JoinStrings(packages, "/", &package_path);
     92         return package_path + "/" + header;
     93     }
     94 }
     95 
     96 template <typename Descriptor>
     97 void OpenNamespaces(Printer& printer, const Descriptor& descriptor) {
     98     const auto namespaces = Packages(descriptor);
     99     for (const auto& ns : namespaces) {
    100         std::map<std::string, std::string> namespaceVars;
    101         namespaceVars["namespace"] = ns;
    102         printer.Print(namespaceVars, R"(
    103 namespace $namespace$ {)");
    104     }
    105 }
    106 
    107 template <typename Descriptor>
    108 void CloseNamespaces(Printer& printer, const Descriptor& descriptor) {
    109     const auto namespaces = Packages(descriptor);
    110     for (auto it = namespaces.crbegin(); it != namespaces.crend(); ++it) {
    111         std::map<std::string, std::string> namespaceVars;
    112         namespaceVars["namespace"] = *it;
    113         printer.Print(namespaceVars, R"(
    114 } // namespace $namespace$)");
    115     }
    116 }
    117 
    118 void ForEachMethod(const ServiceDescriptor& service,
    119                    std::function<void(std::map<std::string, std::string>)> handler) {
    120     for (int i = 0; i < service.method_count(); ++i) {
    121         const MethodDescriptor& method = *service.method(i);
    122         std::map<std::string, std::string> vars;
    123         vars["method_id"] = std::to_string(i);
    124         vars["method_name"] = method.name();
    125         vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type());
    126         vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type());
    127         handler(vars);
    128     }
    129 }
    130 
    131 void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) {
    132     std::map<std::string, std::string> vars;
    133     vars["include_guard"] = "PROTOC_GENERATED_MOCK_" + service.name() + "_CLIENT_H";
    134     vars["service_header"] = service.name() + ".client.h";
    135     vars["mock_class"] = "Mock" + service.name();
    136     vars["class"] = service.name();
    137 
    138     printer.Print(vars, R"(
    139 #ifndef $include_guard$
    140 #define $include_guard$
    141 
    142 #include <gmock/gmock.h>
    143 
    144 #include <$service_header$>)");
    145 
    146     OpenNamespaces(printer, service);
    147 
    148     printer.Print(vars, R"(
    149 struct $mock_class$ : public I$class$ {)");
    150 
    151     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
    152         printer.Print(methodVars, R"(
    153     MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)");
    154     });
    155 
    156     printer.Print(vars, R"(
    157 };)");
    158 
    159     CloseNamespaces(printer, service);
    160 
    161     printer.Print(vars, R"(
    162 #endif)");
    163 }
    164 
    165 void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) {
    166     std::map<std::string, std::string> vars;
    167     vars["include_guard"] = "PROTOC_GENERATED_" + service.name() + "_CLIENT_H";
    168     vars["protobuf_header"] = FullyQualifiedHeader(service);
    169     vars["class"] = service.name();
    170     vars["iface_class"] = "I" + service.name();
    171     vars["app_id"] = "APP_ID_" + service.options().GetExtension(app_id);
    172 
    173     printer.Print(vars, R"(
    174 #ifndef $include_guard$
    175 #define $include_guard$
    176 
    177 #include <application.h>
    178 #include <nos/AppClient.h>
    179 #include <nos/NuggetClientInterface.h>
    180 
    181 #include "$protobuf_header$")");
    182 
    183     OpenNamespaces(printer, service);
    184 
    185     // Pure virtual interface to make testing easier
    186     printer.Print(vars, R"(
    187 class $iface_class$ {
    188 public:
    189     virtual ~$iface_class$() = default;)");
    190 
    191     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
    192         printer.Print(methodVars, R"(
    193     virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)");
    194     });
    195 
    196     printer.Print(vars, R"(
    197 };)");
    198 
    199     // Implementation of the interface for Nugget
    200     printer.Print(vars, R"(
    201 class $class$ : public $iface_class$ {
    202     ::nos::AppClient _app;
    203 public:
    204     $class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {}
    205     ~$class$() override = default;)");
    206 
    207     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
    208         printer.Print(methodVars, R"(
    209     uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)");
    210     });
    211 
    212     printer.Print(vars, R"(
    213 };)");
    214 
    215     CloseNamespaces(printer, service);
    216 
    217     printer.Print(vars, R"(
    218 #endif)");
    219 }
    220 
    221 void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) {
    222     std::map<std::string, std::string> vars;
    223     vars["generated_header"] = service.name() + ".client.h";
    224     vars["class"] = service.name();
    225 
    226     const uint32_t max_request_size = service.options().GetExtension(request_buffer_size);
    227     const uint32_t max_response_size = service.options().GetExtension(response_buffer_size);
    228     vars["max_request_size"] = std::to_string(max_request_size);
    229     vars["max_response_size"] = std::to_string(max_response_size);
    230 
    231     printer.Print(vars, R"(
    232 #include <$generated_header$>
    233 
    234 #include <application.h>)");
    235 
    236     OpenNamespaces(printer, service);
    237 
    238     // Methods
    239     ForEachMethod(service, [&](std::map<std::string, std::string>  methodVars) {
    240         methodVars.insert(vars.begin(), vars.end());
    241         printer.Print(methodVars, R"(
    242 uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) {
    243     const size_t request_size = request.ByteSize();
    244     if (request_size > $max_request_size$) {
    245         return APP_ERROR_TOO_MUCH;
    246     }
    247     std::vector<uint8_t> buffer(request_size);
    248     if (!request.SerializeToArray(buffer.data(), buffer.size())) {
    249         return APP_ERROR_RPC;
    250     }
    251     std::vector<uint8_t> responseBuffer;
    252     if (response != nullptr) {
    253       responseBuffer.resize($max_response_size$);
    254     }
    255     const uint32_t appStatus = _app.Call($method_id$, buffer,
    256                                          (response != nullptr) ? &responseBuffer : nullptr);
    257     if (appStatus == APP_SUCCESS && response != nullptr) {
    258         if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) {
    259             return APP_ERROR_RPC;
    260         }
    261     }
    262     return appStatus;
    263 })");
    264     });
    265 
    266     CloseNamespaces(printer, service);
    267 }
    268 
    269 // Generator for C++ Nugget service client
    270 class CppNuggetServiceClientGenerator : public CodeGenerator {
    271 public:
    272     CppNuggetServiceClientGenerator() = default;
    273     ~CppNuggetServiceClientGenerator() override = default;
    274 
    275     bool Generate(const FileDescriptor* file,
    276                   const std::string& parameter,
    277                   OutputDirectory* output_directory,
    278                   std::string* error) const override {
    279         for (int i = 0; i < file->service_count(); ++i) {
    280             const auto& service = *file->service(i);
    281 
    282             *error = validateServiceOptions(service);
    283             if (!error->empty()) {
    284                 return false;
    285             }
    286 
    287             if (parameter == "mock") {
    288                 std::unique_ptr<ZeroCopyOutputStream> output{
    289                         output_directory->Open("Mock" + service.name() + ".client.h")};
    290                 Printer printer(output.get(), '$');
    291                 GenerateMockClient(printer, service);
    292             } else if (parameter == "header") {
    293                 std::unique_ptr<ZeroCopyOutputStream> output{
    294                         output_directory->Open(service.name() + ".client.h")};
    295                 Printer printer(output.get(), '$');
    296                 GenerateClientHeader(printer, service);
    297             } else if (parameter == "source") {
    298                 std::unique_ptr<ZeroCopyOutputStream> output{
    299                         output_directory->Open(service.name() + ".client.cpp")};
    300                 Printer printer(output.get(), '$');
    301                 GenerateClientSource(printer, service);
    302             } else {
    303                 *error = "Illegal parameter: must be mock|header|source";
    304                 return false;
    305             }
    306         }
    307 
    308         return true;
    309     }
    310 
    311 private:
    312     GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppNuggetServiceClientGenerator);
    313 };
    314 
    315 } // namespace
    316 
    317 int main(int argc, char* argv[]) {
    318     GOOGLE_PROTOBUF_VERIFY_VERSION;
    319     CppNuggetServiceClientGenerator generator;
    320     return google::protobuf::compiler::PluginMain(argc, argv, &generator);
    321 }
    322