/ src / ipc / libmultiprocess / src / mp / gen.cpp
gen.cpp
  1  // Copyright (c) The Bitcoin Core developers
  2  // Distributed under the MIT software license, see the accompanying
  3  // file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <mp/config.h>
  6  #include <mp/util.h>
  7  
  8  #include <algorithm>
  9  #include <capnp/schema.h>
 10  #include <capnp/schema-parser.h>
 11  #include <cerrno>
 12  #include <cstdint>
 13  #include <cstdio>
 14  #include <cstdlib>
 15  #include <fstream>
 16  #include <functional>
 17  #include <initializer_list>
 18  #include <iostream>
 19  #include <kj/array.h>
 20  #include <kj/common.h>
 21  #include <kj/filesystem.h>
 22  #include <kj/memory.h>
 23  #include <kj/string.h>
 24  #include <map>
 25  #include <set>
 26  #include <sstream>
 27  #include <stdexcept>
 28  #include <string>
 29  #include <system_error>
 30  #include <unistd.h>
 31  #include <utility>
 32  #include <vector>
 33  
 34  #define PROXY_BIN "mpgen"
 35  #define PROXY_DECL "mp/proxy.h"
 36  #define PROXY_TYPES "mp/proxy-types.h"
 37  
 38  constexpr uint64_t NAMESPACE_ANNOTATION_ID = 0xb9c6f99ebf805f2cull; // From c++.capnp
 39  constexpr uint64_t INCLUDE_ANNOTATION_ID = 0xb899f3c154fdb458ull;   // From proxy.capnp
 40  constexpr uint64_t INCLUDE_TYPES_ANNOTATION_ID = 0xbcec15648e8a0cf1ull; // From proxy.capnp
 41  constexpr uint64_t WRAP_ANNOTATION_ID = 0xe6f46079b7b1405eull;      // From proxy.capnp
 42  constexpr uint64_t COUNT_ANNOTATION_ID = 0xd02682b319f69b38ull;     // From proxy.capnp
 43  constexpr uint64_t EXCEPTION_ANNOTATION_ID = 0x996a183200992f88ull; // From proxy.capnp
 44  constexpr uint64_t NAME_ANNOTATION_ID = 0xb594888f63f4dbb9ull;      // From proxy.capnp
 45  constexpr uint64_t SKIP_ANNOTATION_ID = 0x824c08b82695d8ddull;      // From proxy.capnp
 46  
 47  template <typename Reader>
 48  static bool AnnotationExists(const Reader& reader, uint64_t id)
 49  {
 50      for (const auto annotation : reader.getAnnotations()) {
 51          if (annotation.getId() == id) {
 52              return true;
 53          }
 54      }
 55      return false;
 56  }
 57  
 58  template <typename Reader>
 59  static bool GetAnnotationText(const Reader& reader, uint64_t id, kj::StringPtr* result)
 60  {
 61      for (const auto annotation : reader.getAnnotations()) {
 62          if (annotation.getId() == id) {
 63              *result = annotation.getValue().getText();
 64              return true;
 65          }
 66      }
 67      return false;
 68  }
 69  
 70  template <typename Reader>
 71  static bool GetAnnotationInt32(const Reader& reader, uint64_t id, int32_t* result)
 72  {
 73      for (const auto annotation : reader.getAnnotations()) {
 74          if (annotation.getId() == id) {
 75              *result = annotation.getValue().getInt32();
 76              return true;
 77          }
 78      }
 79      return false;
 80  }
 81  
 82  static void ForEachMethod(const capnp::InterfaceSchema& interface, const std::function<void(const capnp::InterfaceSchema& interface, const capnp::InterfaceSchema::Method)>& callback) // NOLINT(misc-no-recursion)
 83  {
 84      for (const auto super : interface.getSuperclasses()) {
 85          ForEachMethod(super, callback);
 86      }
 87      for (const auto method : interface.getMethods()) {
 88          callback(interface, method);
 89      }
 90  }
 91  
// Shorthand for kj::ArrayPtr<const char>. Not referenced elsewhere in the
// visible part of this file.
using CharSlice = kj::ArrayPtr<const char>;
 93  
 94  // Overload for any type with a string .begin(), like kj::StringPtr and kj::ArrayPtr<char>.
 95  template <class OutputStream, class Array, const char* Enable = decltype(std::declval<Array>().begin())()>
 96  static OutputStream& operator<<(OutputStream& os, const Array& array)
 97  {
 98      os.write(array.begin(), array.size());
 99      return os;
100  }
101  
// Small helper for building a std::string inline with stream syntax:
//     const std::string s = Format() << "x=" << 1;
struct Format
{
    // Accumulates everything streamed into this object.
    std::ostringstream m_os;

    // Append `value` to the buffer and return *this so calls can be chained.
    template <typename Value>
    Format& operator<<(Value&& value)
    {
        this->m_os << value;
        return *this;
    }

    // Implicit conversion yielding the accumulated text.
    operator std::string() const
    {
        return this->m_os.str();
    }
};
113  
114  static std::string Cap(kj::StringPtr str)
115  {
116      std::string result = str;
117      if (!result.empty() && 'a' <= result[0] && result[0] <= 'z') result[0] -= 'a' - 'A';
118      return result;
119  }
120  
121  static bool BoxedType(const ::capnp::Type& type)
122  {
123      return !(type.isVoid() || type.isBool() || type.isInt8() || type.isInt16() || type.isInt32() || type.isInt64() ||
124               type.isUInt8() || type.isUInt16() || type.isUInt32() || type.isUInt64() || type.isFloat32() ||
125               type.isFloat64() || type.isEnum());
126  }
127  
128  // src_file is path to .capnp file to generate stub code from.
129  //
130  // src_prefix can be used to generate outputs in a different directory than the
131  // source directory. For example if src_file is "/a/b/c/d/file.canp", and
132  // src_prefix is "/a/b", then output files will be "c/d/file.capnp.h"
133  // "c/d/file.capnp.cxx" "c/d/file.capnp.proxy.h", etc. This is equivalent to
134  // the capnp "--src-prefix" option (see "capnp help compile").
135  //
136  // include_prefix can be used to control relative include paths used in
137  // generated files. For example if src_file is "/a/b/c/d/file.canp" and
138  // include_prefix is "/a/b/c" include lines like
139  // "#include <d/file.capnp.proxy.h>", "#include <d/file.capnp.proxy-types.h>"
140  // will be generated.
141  static void Generate(kj::StringPtr src_prefix,
142      kj::StringPtr include_prefix,
143      kj::StringPtr src_file,
144      const std::vector<kj::StringPtr>& import_paths,
145      const kj::ReadableDirectory& src_dir,
146      const std::vector<kj::Own<const kj::ReadableDirectory>>& import_dirs)
147  {
148      std::string output_path;
149      if (src_prefix == kj::StringPtr{"."}) {
150          output_path = src_file;
151      } else if (!src_file.startsWith(src_prefix) || src_file.size() <= src_prefix.size() ||
152                 src_file[src_prefix.size()] != '/') {
153          throw std::runtime_error("src_prefix is not src_file prefix");
154      } else {
155          output_path = src_file.slice(src_prefix.size() + 1);
156      }
157  
158      std::string include_path;
159      if (include_prefix == kj::StringPtr{"."}) {
160          include_path = src_file;
161      } else if (!src_file.startsWith(include_prefix) || src_file.size() <= include_prefix.size() ||
162                 src_file[include_prefix.size()] != '/') {
163          throw std::runtime_error("include_prefix is not src_file prefix");
164      } else {
165          include_path = src_file.slice(include_prefix.size() + 1);
166      }
167  
168      std::string include_base = include_path;
169      const std::string::size_type p = include_base.rfind('.');
170      if (p != std::string::npos) include_base.erase(p);
171  
172      std::vector<std::string> args;
173      args.emplace_back(capnp_PREFIX "/bin/capnp");
174      args.emplace_back("compile");
175      args.emplace_back("--src-prefix=");
176      args.back().append(src_prefix.cStr(), src_prefix.size());
177      for (const auto& import_path : import_paths) {
178          args.emplace_back("--import-path=");
179          args.back().append(import_path.cStr(), import_path.size());
180      }
181      args.emplace_back("--output=" capnp_PREFIX "/bin/capnpc-c++");
182      args.emplace_back(src_file);
183      const int pid = fork();
184      if (pid == -1) {
185          throw std::system_error(errno, std::system_category(), "fork");
186      }
187      if (!pid) {
188          mp::ExecProcess(args);
189      }
190      const int status = mp::WaitProcess(pid);
191      if (status) {
192          throw std::runtime_error("Invoking " capnp_PREFIX "/bin/capnp failed");
193      }
194  
195      const capnp::SchemaParser parser;
196      auto directory_pointers = kj::heapArray<const kj::ReadableDirectory*>(import_dirs.size());
197      for (size_t i = 0; i < import_dirs.size(); ++i) {
198          directory_pointers[i] = import_dirs[i].get();
199      }
200      auto file_schema = parser.parseFromDirectory(src_dir, kj::Path::parse(output_path), directory_pointers);
201  
202      std::ofstream cpp_server(output_path + ".proxy-server.c++");
203      cpp_server << "// Generated by " PROXY_BIN " from " << src_file << "\n\n";
204      cpp_server << "// IWYU pragma: no_include <kj/memory.h>\n";
205      cpp_server << "// IWYU pragma: no_include <memory>\n";
206      cpp_server << "// IWYU pragma: begin_keep\n";
207      cpp_server << "#include <" << include_path << ".proxy.h>\n";
208      cpp_server << "#include <" << include_path << ".proxy-types.h>\n";
209      cpp_server << "#include <capnp/generated-header-support.h>\n";
210      cpp_server << "#include <cstring>\n";
211      cpp_server << "#include <kj/async.h>\n";
212      cpp_server << "#include <kj/common.h>\n";
213      cpp_server << "#include <kj/exception.h>\n";
214      cpp_server << "#include <kj/tuple.h>\n";
215      cpp_server << "#include <mp/proxy.h>\n";
216      cpp_server << "#include <mp/util.h>\n";
217      cpp_server << "#include <" << PROXY_TYPES << ">\n";
218      cpp_server << "// IWYU pragma: end_keep\n\n";
219      cpp_server << "namespace mp {\n";
220  
221      std::ofstream cpp_client(output_path + ".proxy-client.c++");
222      cpp_client << "// Generated by " PROXY_BIN " from " << src_file << "\n\n";
223      cpp_client << "// IWYU pragma: no_include <kj/memory.h>\n";
224      cpp_client << "// IWYU pragma: no_include <memory>\n";
225      cpp_client << "// IWYU pragma: begin_keep\n";
226      cpp_client << "#include <" << include_path << ".h>\n";
227      cpp_client << "#include <" << include_path << ".proxy.h>\n";
228      cpp_client << "#include <" << include_path << ".proxy-types.h>\n";
229      cpp_client << "#include <capnp/generated-header-support.h>\n";
230      cpp_client << "#include <cstring>\n";
231      cpp_client << "#include <kj/common.h>\n";
232      cpp_client << "#include <mp/proxy.h>\n";
233      cpp_client << "#include <mp/util.h>\n";
234      cpp_client << "#include <" << PROXY_TYPES << ">\n";
235      cpp_client << "// IWYU pragma: end_keep\n\n";
236      cpp_client << "namespace mp {\n";
237  
238      std::ofstream cpp_types(output_path + ".proxy-types.c++");
239      cpp_types << "// Generated by " PROXY_BIN " from " << src_file << "\n\n";
240      cpp_types << "// IWYU pragma: no_include \"mp/proxy.h\"\n";
241      cpp_types << "// IWYU pragma: no_include \"mp/proxy-io.h\"\n";
242      cpp_types << "#include <" << include_path << ".proxy.h>\n";
243      cpp_types << "#include <" << include_path << ".proxy-types.h> // IWYU pragma: keep\n";
244      cpp_types << "#include <" << PROXY_TYPES << ">\n\n";
245      cpp_types << "namespace mp {\n";
246  
247      std::string guard = output_path;
248      std::ranges::transform(guard, guard.begin(), [](unsigned char c) -> unsigned char {
249          if ('0' <= c && c <= '9') return c;
250          if ('A' <= c && c <= 'Z') return c;
251          if ('a' <= c && c <= 'z') return c - 'a' + 'A';
252          return '_';
253      });
254  
255      std::ofstream inl(output_path + ".proxy-types.h");
256      inl << "// Generated by " PROXY_BIN " from " << src_file << "\n\n";
257      inl << "#ifndef " << guard << "_PROXY_TYPES_H\n";
258      inl << "#define " << guard << "_PROXY_TYPES_H\n\n";
259      inl << "// IWYU pragma: no_include \"mp/proxy.h\"\n";
260      inl << "#include <mp/proxy.h> // IWYU pragma: keep\n";
261      inl << "#include <" << include_path << ".proxy.h> // IWYU pragma: keep\n";
262      for (const auto annotation : file_schema.getProto().getAnnotations()) {
263          if (annotation.getId() == INCLUDE_TYPES_ANNOTATION_ID) {
264              inl << "#include \"" << annotation.getValue().getText() << "\" // IWYU pragma: export\n";
265          }
266      }
267      inl << "namespace mp {\n";
268  
269      std::ofstream h(output_path + ".proxy.h");
270      h << "// Generated by " PROXY_BIN " from " << src_file << "\n\n";
271      h << "#ifndef " << guard << "_PROXY_H\n";
272      h << "#define " << guard << "_PROXY_H\n\n";
273      h << "#include <" << include_path << ".h> // IWYU pragma: keep\n";
274      for (const auto annotation : file_schema.getProto().getAnnotations()) {
275          if (annotation.getId() == INCLUDE_ANNOTATION_ID) {
276              h << "#include \"" << annotation.getValue().getText() << "\" // IWYU pragma: export\n";
277          }
278      }
279      h << "#include <" << PROXY_DECL << ">\n\n";
280      h << "#if defined(__GNUC__)\n";
281      h << "#pragma GCC diagnostic push\n";
282      h << "#if !defined(__has_warning)\n";
283      h << "#pragma GCC diagnostic ignored \"-Wsuggest-override\"\n";
284      h << "#elif __has_warning(\"-Wsuggest-override\")\n";
285      h << "#pragma GCC diagnostic ignored \"-Wsuggest-override\"\n";
286      h << "#endif\n";
287      h << "#endif\n";
288      h << "namespace mp {\n";
289  
290      kj::StringPtr message_namespace;
291      GetAnnotationText(file_schema.getProto(), NAMESPACE_ANNOTATION_ID, &message_namespace);
292  
293      std::string base_name = include_base;
294      const size_t output_slash = base_name.rfind('/');
295      if (output_slash != std::string::npos) {
296          base_name.erase(0, output_slash + 1);
297      }
298  
299      std::ostringstream methods;
300      std::set<kj::StringPtr> accessors_done;
301      std::ostringstream accessors;
302      std::ostringstream dec;
303      std::ostringstream def_server;
304      std::ostringstream def_client;
305      std::ostringstream int_client;
306      std::ostringstream def_types;
307  
308      auto add_accessor = [&](kj::StringPtr name) {
309          if (!accessors_done.insert(name).second) return;
310          const std::string cap = Cap(name);
311          accessors << "struct " << cap << "\n";
312          accessors << "{\n";
313          accessors << "    template<typename S> static auto get(S&& s) -> decltype(s.get" << cap << "()) { return s.get" << cap << "(); }\n";
314          accessors << "    template<typename S> static bool has(S&& s) { return s.has" << cap << "(); }\n";
315          accessors << "    template<typename S, typename A> static void set(S&& s, A&& a) { s.set" << cap
316                    << "(std::forward<A>(a)); }\n";
317          accessors << "    template<typename S, typename... A> static decltype(auto) init(S&& s, A&&... a) { return s.init"
318                    << cap << "(std::forward<A>(a)...); }\n";
319          accessors << "    template<typename S> static bool getWant(S&& s) { return s.getWant" << cap << "(); }\n";
320          accessors << "    template<typename S> static void setWant(S&& s) { s.setWant" << cap << "(true); }\n";
321          accessors << "    template<typename S> static bool getHas(S&& s) { return s.getHas" << cap << "(); }\n";
322          accessors << "    template<typename S> static void setHas(S&& s) { s.setHas" << cap << "(true); }\n";
323          accessors << "};\n";
324      };
325  
326      for (const auto node_nested : file_schema.getProto().getNestedNodes()) {
327          kj::StringPtr node_name = node_nested.getName();
328          const auto& node = file_schema.getNested(node_name);
329          kj::StringPtr proxied_class_type;
330          GetAnnotationText(node.getProto(), WRAP_ANNOTATION_ID, &proxied_class_type);
331  
332          if (node.getProto().isStruct()) {
333              const auto& struc = node.asStruct();
334              std::ostringstream generic_name;
335              generic_name << node_name;
336              dec << "template<";
337              bool first_param = true;
338              for (const auto param : node.getProto().getParameters()) {
339                  if (first_param) {
340                      first_param = false;
341                      generic_name << "<";
342                  } else {
343                      dec << ", ";
344                      generic_name << ", ";
345                  }
346                  dec << "typename " << param.getName();
347                  generic_name << "" << param.getName();
348              }
349              if (!first_param) generic_name << ">";
350              dec << ">\n";
351              dec << "struct ProxyStruct<" << message_namespace << "::" << generic_name.str() << ">\n";
352              dec << "{\n";
353              dec << "    using Struct = " << message_namespace << "::" << generic_name.str() << ";\n";
354              for (const auto field : struc.getFields()) {
355                  auto field_name = field.getProto().getName();
356                  add_accessor(field_name);
357                  dec << "    using " << Cap(field_name) << "Accessor = Accessor<" << base_name
358                      << "_fields::" << Cap(field_name) << ", FIELD_IN | FIELD_OUT";
359                  if (BoxedType(field.getType())) dec << " | FIELD_BOXED";
360                  dec << ">;\n";
361              }
362              dec << "    using Accessors = std::tuple<";
363              size_t i = 0;
364              for (const auto field : struc.getFields()) {
365                  if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
366                      continue;
367                  }
368                  if (i) dec << ", ";
369                  dec << Cap(field.getProto().getName()) << "Accessor";
370                  ++i;
371              }
372              dec << ">;\n";
373              dec << "    static constexpr size_t fields = " << i << ";\n";
374              dec << "};\n";
375  
376              if (proxied_class_type.size()) {
377                  inl << "template<>\n";
378                  inl << "struct ProxyType<" << proxied_class_type << ">\n";
379                  inl << "{\n";
380                  inl << "public:\n";
381                  inl << "    using Struct = " << message_namespace << "::" << node_name << ";\n";
382                  size_t i = 0;
383                  for (const auto field : struc.getFields()) {
384                      if (AnnotationExists(field.getProto(), SKIP_ANNOTATION_ID)) {
385                          continue;
386                      }
387                      auto field_name = field.getProto().getName();
388                      auto member_name = field_name;
389                      GetAnnotationText(field.getProto(), NAME_ANNOTATION_ID, &member_name);
390                      inl << "    static decltype(auto) get(std::integral_constant<size_t, " << i << ">) { return "
391                          << "&" << proxied_class_type << "::" << member_name << "; }\n";
392                      ++i;
393                  }
394                  inl << "    static constexpr size_t fields = " << i << ";\n";
395                  inl << "};\n";
396              }
397          }
398  
399          if (proxied_class_type.size() && node.getProto().isInterface()) {
400              const auto& interface = node.asInterface();
401  
402              std::ostringstream client;
403              client << "template<>\nstruct ProxyClient<" << message_namespace << "::" << node_name << "> final : ";
404              client << "public ProxyClientCustom<" << message_namespace << "::" << node_name << ", "
405                     << proxied_class_type << ">\n{\n";
406              client << "public:\n";
407              client << "    using ProxyClientCustom::ProxyClientCustom;\n";
408              client << "    ~ProxyClient();\n";
409  
410              std::ostringstream server;
411              server << "template<>\nstruct ProxyServer<" << message_namespace << "::" << node_name << "> : public "
412                     << "ProxyServerCustom<" << message_namespace << "::" << node_name << ", " << proxied_class_type
413                     << ">\n{\n";
414              server << "public:\n";
415              server << "    using ProxyServerCustom::ProxyServerCustom;\n";
416              server << "    ~ProxyServer();\n";
417  
418              const std::ostringstream client_construct;
419              const std::ostringstream client_destroy;
420  
421              int method_ordinal = 0;
422              ForEachMethod(interface, [&] (const capnp::InterfaceSchema& method_interface, const capnp::InterfaceSchema::Method& method) {
423                  const kj::StringPtr method_name = method.getProto().getName();
424                  kj::StringPtr proxied_method_name = method_name;
425                  GetAnnotationText(method.getProto(), NAME_ANNOTATION_ID, &proxied_method_name);
426  
427                  const std::string method_prefix = Format() << message_namespace << "::" << method_interface.getShortDisplayName()
428                                                             << "::" << Cap(method_name);
429                  const bool is_construct = method_name == kj::StringPtr{"construct"};
430                  const bool is_destroy = method_name == kj::StringPtr{"destroy"};
431  
432                  struct Field
433                  {
434                      ::capnp::StructSchema::Field param;
435                      bool param_is_set = false;
436                      ::capnp::StructSchema::Field result;
437                      bool result_is_set = false;
438                      int args = 0;
439                      bool retval = false;
440                      bool optional = false;
441                      bool requested = false;
442                      bool skip = false;
443                      kj::StringPtr exception;
444                  };
445  
446                  std::vector<Field> fields;
447                  std::map<kj::StringPtr, int> field_idx; // name -> args index
448                  bool has_result = false;
449  
450                  auto add_field = [&](const ::capnp::StructSchema::Field& schema_field, bool param) {
451                      if (AnnotationExists(schema_field.getProto(), SKIP_ANNOTATION_ID)) {
452                          return;
453                      }
454  
455                      auto field_name = schema_field.getProto().getName();
456                      auto inserted = field_idx.emplace(field_name, fields.size());
457                      if (inserted.second) {
458                          fields.emplace_back();
459                      }
460                      auto& field = fields[inserted.first->second];
461                      if (param) {
462                          field.param = schema_field;
463                          field.param_is_set = true;
464                      } else {
465                          field.result = schema_field;
466                          field.result_is_set = true;
467                      }
468  
469                      if (!param && field_name == kj::StringPtr{"result"}) {
470                          field.retval = true;
471                          has_result = true;
472                      }
473  
474                      GetAnnotationText(schema_field.getProto(), EXCEPTION_ANNOTATION_ID, &field.exception);
475  
476                      int32_t count = 1;
477                      if (!GetAnnotationInt32(schema_field.getProto(), COUNT_ANNOTATION_ID, &count)) {
478                          if (schema_field.getType().isStruct()) {
479                              GetAnnotationInt32(schema_field.getType().asStruct().getProto(),
480                                      COUNT_ANNOTATION_ID, &count);
481                          } else if (schema_field.getType().isInterface()) {
482                              GetAnnotationInt32(schema_field.getType().asInterface().getProto(),
483                                      COUNT_ANNOTATION_ID, &count);
484                          }
485                      }
486  
487  
488                      if (inserted.second && !field.retval && !field.exception.size()) {
489                          field.args = count;
490                      }
491                  };
492  
493                  for (const auto schema_field : method.getParamType().getFields()) {
494                      add_field(schema_field, true);
495                  }
496                  for (const auto schema_field : method.getResultType().getFields()) {
497                      add_field(schema_field, false);
498                  }
499                  for (auto& field : field_idx) {
500                      auto has_field = field_idx.find("has" + Cap(field.first));
501                      if (has_field != field_idx.end()) {
502                          fields[has_field->second].skip = true;
503                          fields[field.second].optional = true;
504                      }
505                      auto want_field = field_idx.find("want" + Cap(field.first));
506                      if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
507                          fields[want_field->second].skip = true;
508                          fields[field.second].requested = true;
509                      }
510                  }
511  
512                  if (!is_construct && !is_destroy && (&method_interface == &interface)) {
513                      methods << "template<>\n";
514                      methods << "struct ProxyMethod<" << method_prefix << "Params>\n";
515                      methods << "{\n";
516                      methods << "    static constexpr auto impl = &" << proxied_class_type
517                              << "::" << proxied_method_name << ";\n";
518                      methods << "};\n\n";
519                  }
520  
521                  std::ostringstream client_args;
522                  std::ostringstream client_invoke;
523                  std::ostringstream server_invoke_start;
524                  std::ostringstream server_invoke_end;
525                  int argc = 0;
526                  for (const auto& field : fields) {
527                      if (field.skip) continue;
528  
529                      const auto& f = field.param_is_set ? field.param : field.result;
530                      auto field_name = f.getProto().getName();
531                      auto field_type = f.getType();
532  
533                      std::ostringstream field_flags;
534                      if (!field.param_is_set) {
535                          field_flags << "FIELD_OUT";
536                      } else if (field.result_is_set) {
537                          field_flags << "FIELD_IN | FIELD_OUT";
538                      } else {
539                          field_flags << "FIELD_IN";
540                      }
541                      if (field.optional) field_flags << " | FIELD_OPTIONAL";
542                      if (field.requested) field_flags << " | FIELD_REQUESTED";
543                      if (BoxedType(field_type)) field_flags << " | FIELD_BOXED";
544  
545                      add_accessor(field_name);
546  
547                      std::ostringstream fwd_args;
548                      for (int i = 0; i < field.args; ++i) {
549                          if (argc > 0) client_args << ",";
550  
551                          // Add to client method parameter list.
552                          client_args << "M" << method_ordinal << "::Param<" << argc << "> " << field_name;
553                          if (field.args > 1) client_args << i;
554  
555                          // Add to MakeClientParam argument list using Fwd helper for perfect forwarding.
556                          if (i > 0) fwd_args << ", ";
557                          fwd_args << "M" << method_ordinal << "::Fwd<" << argc << ">(" << field_name;
558                          if (field.args > 1) fwd_args << i;
559                          fwd_args << ")";
560  
561                          ++argc;
562                      }
563                      client_invoke << ", ";
564  
565                      if (field.exception.size()) {
566                          client_invoke << "ClientException<" << field.exception << ", ";
567                      } else {
568                          client_invoke << "MakeClientParam<";
569                      }
570  
571                      client_invoke << "Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
572                                    << field_flags.str() << ">>(";
573  
574                      if (field.retval) {
575                          client_invoke << field_name;
576                      } else {
577                          client_invoke << fwd_args.str();
578                      }
579                      client_invoke << ")";
580  
581                      if (field.exception.size()) {
582                          server_invoke_start << "Make<ServerExcept, " << field.exception;
583                      } else if (field.retval) {
584                          server_invoke_start << "Make<ServerRet";
585                      } else {
586                          server_invoke_start << "MakeServerField<" << field.args;
587                      }
588                      server_invoke_start << ", Accessor<" << base_name << "_fields::" << Cap(field_name) << ", "
589                                          << field_flags.str() << ">>(";
590                      server_invoke_end << ")";
591                  }
592  
593                  const std::string static_str{is_construct || is_destroy ? "static " : ""};
594                  const std::string super_str{is_construct || is_destroy ? "Super& super" : ""};
595                  const std::string self_str{is_construct || is_destroy ? "super" : "*this"};
596  
597                  client << "    using M" << method_ordinal << " = ProxyClientMethodTraits<" << method_prefix
598                         << "Params>;\n";
599                  client << "    " << static_str << "typename M" << method_ordinal << "::Result " << method_name << "("
600                         << super_str << client_args.str() << ")";
601                  client << ";\n";
602                  def_client << "ProxyClient<" << message_namespace << "::" << node_name << ">::M" << method_ordinal
603                             << "::Result ProxyClient<" << message_namespace << "::" << node_name << ">::" << method_name
604                             << "(" << super_str << client_args.str() << ") {\n";
605                  if (has_result) {
606                      def_client << "    typename M" << method_ordinal << "::Result result;\n";
607                  }
608                  def_client << "    clientInvoke(" << self_str << ", &" << message_namespace << "::" << node_name
609                             << "::Client::" << method_name << "Request" << client_invoke.str() << ");\n";
610                  if (has_result) def_client << "    return result;\n";
611                  def_client << "}\n";
612  
613                  server << "    kj::Promise<void> " << method_name << "(" << Cap(method_name)
614                         << "Context call_context) override;\n";
615  
616                  def_server << "kj::Promise<void> ProxyServer<" << message_namespace << "::" << node_name
617                             << ">::" << method_name << "(" << Cap(method_name)
618                             << "Context call_context) {\n"
619                                "    return serverInvoke(*this, call_context, "
620                             << server_invoke_start.str();
621                  if (is_destroy) {
622                      def_server << "ServerDestroy()";
623                  } else {
624                      def_server << "ServerCall()";
625                  }
626                  def_server << server_invoke_end.str() << ");\n}\n";
627                  ++method_ordinal;
628              });
629  
630              client << "};\n";
631              server << "};\n";
632              dec << "\n" << client.str() << "\n" << server.str() << "\n";
633              KJ_IF_MAYBE(bracket, proxied_class_type.findFirst('<')) {
634                // Skip ProxyType definition for complex type expressions which
635                // could lead to duplicate definitions. They can be defined
636                // manually if actually needed.
637              } else {
638                dec << "template<>\nstruct ProxyType<" << proxied_class_type << ">\n{\n";
639                dec << "    using Type = " << proxied_class_type << ";\n";
640                dec << "    using Message = " << message_namespace << "::" << node_name << ";\n";
641                dec << "    using Client = ProxyClient<Message>;\n";
642                dec << "    using Server = ProxyServer<Message>;\n";
643                dec << "};\n";
644                int_client << "ProxyTypeRegister t" << node_nested.getId() << "{TypeList<" << proxied_class_type << ">{}};\n";
645              }
646              def_types << "ProxyClient<" << message_namespace << "::" << node_name
647                        << ">::~ProxyClient() { clientDestroy(*this); " << client_destroy.str() << " }\n";
648              def_types << "ProxyServer<" << message_namespace << "::" << node_name
649                        << ">::~ProxyServer() { serverDestroy(*this); }\n";
650          }
651      }
652  
653      h << methods.str() << "namespace " << base_name << "_fields {\n"
654        << accessors.str() << "} // namespace " << base_name << "_fields\n"
655        << dec.str();
656  
657      cpp_server << def_server.str();
658      cpp_server << "} // namespace mp\n";
659  
660      cpp_client << def_client.str();
661      cpp_client << "namespace {\n" << int_client.str() << "} // namespace\n";
662      cpp_client << "} // namespace mp\n";
663  
664      cpp_types << def_types.str();
665      cpp_types << "} // namespace mp\n";
666  
667      inl << "} // namespace mp\n";
668      inl << "#endif\n";
669  
670      h << "} // namespace mp\n";
671      h << "#if defined(__GNUC__)\n";
672      h << "#pragma GCC diagnostic pop\n";
673      h << "#endif\n";
674      h << "#endif\n";
675  }
676  
677  int main(int argc, char** argv)
678  {
679      if (argc < 3) {
680          std::cerr << "Usage: " << PROXY_BIN << " SRC_PREFIX INCLUDE_PREFIX SRC_FILE [IMPORT_PATH...]\n";
681          exit(1);
682      }
683      std::vector<kj::StringPtr> import_paths;
684      std::vector<kj::Own<const kj::ReadableDirectory>> import_dirs;
685      auto fs = kj::newDiskFilesystem();
686      auto cwd = fs->getCurrentPath();
687      kj::Own<const kj::ReadableDirectory> src_dir;
688      KJ_IF_MAYBE(dir, fs->getRoot().tryOpenSubdir(cwd.evalNative(argv[1]))) {
689          src_dir = kj::mv(*dir);
690      } else {
691          throw std::runtime_error(std::string("Failed to open src_prefix prefix directory: ") + argv[1]);
692      }
693      for (int i = 4; i < argc; ++i) {
694          KJ_IF_MAYBE(dir, fs->getRoot().tryOpenSubdir(cwd.evalNative(argv[i]))) {
695              import_paths.emplace_back(argv[i]);
696              import_dirs.emplace_back(kj::mv(*dir));
697          } else {
698              throw std::runtime_error(std::string("Failed to open import directory: ") + argv[i]);
699          }
700      }
701      for (const char* path : {CMAKE_INSTALL_PREFIX "/include", capnp_PREFIX "/include"}) {
702          KJ_IF_MAYBE(dir, fs->getRoot().tryOpenSubdir(cwd.evalNative(path))) {
703              import_paths.emplace_back(path);
704              import_dirs.emplace_back(kj::mv(*dir));
705          }
706          // No exception thrown if _PREFIX directories do not exist
707      }
708      Generate(argv[1], argv[2], argv[3], import_paths, *src_dir, import_dirs);
709      return 0;
710  }