/ toolchains / python.bzl
python.bzl
  1  # toolchains/python.bzl
  2  #
  3  # Python toolchain with nanobind for C++ bindings.
  4  #
# Paths are read from the [python], [cxx] and [nv] config sections (set in .buckconfig.local).
  6  # Uses Python from Nix devshell with nanobind pre-installed.
  7  #
  8  # For unwrapped clang, we need explicit stdlib include and library paths.
  9  # Nanobind requires compiling its source files along with user code.
 10  
# Nanobind source files that must be compiled with the extension.
# Paths are relative to the nanobind package root; _nanobind_extension_impl
# prefixes each entry with the configured root before passing it to the compiler.
NB_SOURCES = [
    "src/nb_internals.cpp",
    "src/nb_func.cpp",
    "src/nb_type.cpp",
    "src/nb_enum.cpp",
    "src/nb_ndarray.cpp",
    "src/nb_static_property.cpp",
    "src/nb_ft.cpp",
    "src/common.cpp",
    "src/error.cpp",
    "src/trampoline.cpp",
    "src/implicit.cpp",
]
 25  
 26  def _python_script_impl(ctx: AnalysisContext) -> list[Provider]:
 27      """
 28      Python script rule that can depend on extension modules.
 29      
 30      Creates a wrapper script that sets PYTHONPATH to include extension modules.
 31      """
 32      interpreter = read_root_config("python", "interpreter", "python3")
 33      
 34      # Collect extension .so files from deps
 35      ext_outputs = []
 36      for dep in ctx.attrs.deps:
 37          info = dep[DefaultInfo]
 38          for out in info.default_outputs:
 39              if out.short_path.endswith(".so"):
 40                  ext_outputs.append(out)
 41      
 42      if ext_outputs:
 43          # Create wrapper script that sets PYTHONPATH
 44          wrapper = ctx.actions.declare_output(ctx.attrs.name + "_run.sh")
 45          
 46          # Build wrapper content using cmd_args for proper artifact paths
 47          # Use delimiter="" to avoid newlines between args
 48          wrapper_cmd = cmd_args(delimiter = "")
 49          wrapper_cmd.add("#!/bin/bash\n")
 50          wrapper_cmd.add("# Auto-generated wrapper for " + ctx.attrs.name + "\n")
 51          wrapper_cmd.add("export PYTHONPATH=\"")
 52          for i, ext in enumerate(ext_outputs):
 53              if i > 0:
 54                  wrapper_cmd.add(":")
 55              # Use parent format to get directory of .so file
 56              wrapper_cmd.add(cmd_args(ext, parent = 1))
 57          wrapper_cmd.add("${PYTHONPATH:+:$PYTHONPATH}\"\n")
 58          wrapper_cmd.add("exec " + interpreter + " ")
 59          wrapper_cmd.add(ctx.attrs.main)
 60          wrapper_cmd.add(" \"$@\"\n")
 61          
 62          ctx.actions.write(wrapper, wrapper_cmd, is_executable = True)
 63          
 64          return [
 65              DefaultInfo(default_output = wrapper, other_outputs = ext_outputs),
 66              RunInfo(args = [wrapper]),
 67          ]
 68      else:
 69          return [
 70              DefaultInfo(default_output = ctx.attrs.main),
 71              RunInfo(args = [interpreter, ctx.attrs.main]),
 72          ]
 73  
# Rule wrapping _python_script_impl.
python_script = rule(
    impl = _python_script_impl,
    attrs = {
        # The Python source file handed to the interpreter.
        "main": attrs.source(),
        # Targets scanned for default outputs ending in ".so".
        "deps": attrs.list(attrs.dep(), default = []),
    },
)
 81  
 82  def _nanobind_extension_impl(ctx: AnalysisContext) -> list[Provider]:
 83      """
 84      Build a nanobind C++ extension module.
 85  
 86      Uses clang from the :cxx toolchain and nanobind headers from Nix.
 87      Requires unwrapped clang with explicit stdlib include and library paths.
 88      Compiles nanobind source files along with user code.
 89      """
 90      # Get paths from config
 91      cxx = read_root_config("cxx", "cxx", "clang++")
 92      python_include = read_root_config("python", "python_include", "/usr/include/python3.12")
 93      nanobind_path = read_root_config("python", "nanobind_cmake", "")  # Package root with src/
 94      nanobind_include = read_root_config("python", "nanobind_include", "")
 95  
 96      # Stdlib paths for unwrapped clang
 97      gcc_include = read_root_config("cxx", "gcc_include", "")
 98      gcc_include_arch = read_root_config("cxx", "gcc_include_arch", "")
 99      glibc_include = read_root_config("cxx", "glibc_include", "")
100      clang_resource_dir = read_root_config("cxx", "clang_resource_dir", "")
101  
102      # Library paths for linking
103      gcc_lib = read_root_config("cxx", "gcc_lib", "")
104      gcc_lib_base = read_root_config("cxx", "gcc_lib_base", "")
105      glibc_lib = read_root_config("cxx", "glibc_lib", "")
106  
107      # Output .so file
108      out = ctx.actions.declare_output(ctx.attrs.name + ".so")
109  
110      # Compile flags - must include stdlib paths for unwrapped clang
111      compile_flags = [
112          "-std=c++17",  # nanobind requires C++17
113          "-O2",
114          "-fPIC",
115          "-shared",
116          "-fvisibility=hidden",  # nanobind recommendation
117          "-fno-strict-aliasing",
118      ]
119  
120      # Add stdlib include paths (order matters for unwrapped clang)
121      if gcc_include:
122          compile_flags.extend(["-isystem", gcc_include])
123      if gcc_include_arch:
124          compile_flags.extend(["-isystem", gcc_include_arch])
125      if glibc_include:
126          compile_flags.extend(["-isystem", glibc_include])
127      if clang_resource_dir:
128          compile_flags.extend(["-resource-dir=" + clang_resource_dir])
129  
130      # Python and nanobind includes
131      compile_flags.extend(["-isystem", python_include])
132      if nanobind_include:
133          compile_flags.extend(["-isystem", nanobind_include])
134  
135      # Nanobind's external dependencies (robin_map)
136      if nanobind_path:
137          compile_flags.extend(["-isystem", nanobind_path + "/ext/robin_map/include"])
138  
139      # Link flags for unwrapped clang
140      # -B tells clang where to find CRT startup files (crti.o, crtbeginS.o, etc.)
141      link_flags = []
142      if glibc_lib:
143          link_flags.extend(["-B" + glibc_lib, "-L" + glibc_lib])
144      if gcc_lib:
145          link_flags.extend(["-B" + gcc_lib, "-L" + gcc_lib])
146      if gcc_lib_base:
147          link_flags.extend(["-L" + gcc_lib_base])
148      link_flags.extend(["-lstdc++", "-lm", "-ldl", "-lpthread"])
149  
150      # Collect all source files: user sources + nanobind sources
151      all_srcs = [src for src in ctx.attrs.srcs]
152  
153      # Add nanobind source files if path is configured
154      nb_srcs = []
155      if nanobind_path:
156          nb_srcs = [nanobind_path + "/" + src for src in NB_SOURCES]
157  
158      # Build command
159      cmd = cmd_args([
160          cxx,
161      ] + compile_flags + link_flags + [
162          "-o", out.as_output(),
163      ] + all_srcs + nb_srcs)
164  
165      ctx.actions.run(cmd, category = "nanobind_compile")
166  
167      return [
168          DefaultInfo(default_output = out),
169      ]
170  
# Rule wrapping _nanobind_extension_impl; the output is "<name>.so".
nanobind_extension = rule(
    impl = _nanobind_extension_impl,
    attrs = {
        # C++ sources of the extension module.
        "srcs": attrs.list(attrs.source()),
        # Dependency targets (optional, defaults to none).
        "deps": attrs.list(attrs.dep(), default = []),
        # Extra compiler flag strings (optional, defaults to none).
        "compiler_flags": attrs.list(attrs.string(), default = []),
    },
)
179  
# Load NvLibraryInfo from nv.bzl; it is used below to look up the provider on nv_deps.
181  load("@toolchains//:nv.bzl", "NvLibraryInfo")
182  
def _pybind11_extension_impl(ctx: AnalysisContext) -> list[Provider]:
    """
    Build a pybind11 C++ extension module.

    pybind11 is header-only so simpler than nanobind: `srcs` are compiled and
    linked in a single compiler invocation, producing "<name>.so".
    Supports nv_deps for linking CUDA libraries: objects from every nv_dep
    carrying NvLibraryInfo are linked in, and its headers are made available
    to the compile.

    Configuration read from the root config:
      [cxx]    cxx, gcc_include, gcc_include_arch, glibc_include,
               clang_resource_dir, gcc_lib, gcc_lib_base, glibc_lib
      [python] python_include, pybind11_include
      [nv]     nvidia_sdk_include, nvidia_sdk_lib (only used with nv_deps)

    Every optional path is skipped when its config value is empty.

    NOTE: the rule also declares a `deps` attribute; it is not consumed here.

    Returns:
        [DefaultInfo] whose default output is the built shared object.
    """
    # Get paths from config
    cxx = read_root_config("cxx", "cxx", "clang++")
    python_include = read_root_config("python", "python_include", "/usr/include/python3.12")
    pybind11_include = read_root_config("python", "pybind11_include", "")

    # Stdlib paths for unwrapped clang
    gcc_include = read_root_config("cxx", "gcc_include", "")
    gcc_include_arch = read_root_config("cxx", "gcc_include_arch", "")
    glibc_include = read_root_config("cxx", "glibc_include", "")
    clang_resource_dir = read_root_config("cxx", "clang_resource_dir", "")

    # Library paths for linking
    gcc_lib = read_root_config("cxx", "gcc_lib", "")
    gcc_lib_base = read_root_config("cxx", "gcc_lib_base", "")
    glibc_lib = read_root_config("cxx", "glibc_lib", "")

    # NVIDIA SDK for CUDA dependencies
    nvidia_sdk_lib = read_root_config("nv", "nvidia_sdk_lib", "")
    nvidia_sdk_include = read_root_config("nv", "nvidia_sdk_include", "")

    # Output .so file
    out = ctx.actions.declare_output(ctx.attrs.name + ".so")

    # Compile flags
    compile_flags = [
        "-std=c++17",
        "-O2",
        "-fPIC",
        "-shared",
        "-fvisibility=hidden",
    ]

    # Add stdlib include paths
    if gcc_include:
        compile_flags.extend(["-isystem", gcc_include])
    if gcc_include_arch:
        compile_flags.extend(["-isystem", gcc_include_arch])
    if glibc_include:
        compile_flags.extend(["-isystem", glibc_include])
    if clang_resource_dir:
        compile_flags.append("-resource-dir=" + clang_resource_dir)

    # Python and pybind11 includes
    compile_flags.extend(["-isystem", python_include])
    if pybind11_include:
        compile_flags.extend(["-isystem", pybind11_include])

    # Add CUDA include if we have nv_deps
    if ctx.attrs.nv_deps and nvidia_sdk_include:
        compile_flags.extend(["-isystem", nvidia_sdk_include])

    # Collect objects and headers from nv_deps; deps without NvLibraryInfo
    # are ignored.
    nv_objects = []
    nv_headers = []
    for dep in ctx.attrs.nv_deps:
        if NvLibraryInfo in dep:
            nv_info = dep[NvLibraryInfo]
            nv_objects.extend(nv_info.objects)
            nv_headers.extend(nv_info.headers)

    # One "-I ." is enough however many deps contribute headers.
    # NOTE(review): presumably this lets sources include the headers by their
    # path relative to the action's working directory - confirm.
    if nv_headers:
        compile_flags.extend(["-I", "."])

    # Per-target flags come last so they can override the defaults above.
    compile_flags.extend(ctx.attrs.compiler_flags)

    # Search-path flags
    # -B tells clang where to find CRT startup files (crti.o, crtbeginS.o, etc.)
    search_flags = []
    if glibc_lib:
        search_flags.extend(["-B" + glibc_lib, "-L" + glibc_lib])
    if gcc_lib:
        search_flags.extend(["-B" + gcc_lib, "-L" + gcc_lib])
    if gcc_lib_base:
        search_flags.append("-L" + gcc_lib_base)

    # Libraries go after the sources/objects on the command line: the linker
    # resolves symbols left to right, so a library listed before the code that
    # needs it can be skipped (static archives, --as-needed).
    libs = []

    # Add CUDA runtime library if we have nv_deps
    if ctx.attrs.nv_deps and nvidia_sdk_lib:
        search_flags.extend([
            "-L" + nvidia_sdk_lib,
            "-Wl,-rpath," + nvidia_sdk_lib,
        ])

        # Listed ahead of the system libraries it may itself depend on.
        libs.append("-lcudart")
    libs.extend(["-lstdc++", "-lm", "-ldl", "-lpthread"])

    # Build command: compile user sources and link with nv objects.
    # The headers never appear on the command line, so they are attached as
    # hidden inputs; otherwise the action would not depend on them.
    cmd = cmd_args(
        [cxx] +
        compile_flags +
        search_flags +
        ["-o", out.as_output()] +
        list(ctx.attrs.srcs) +
        nv_objects +
        libs,
        hidden = nv_headers,
    )

    ctx.actions.run(cmd, category = "pybind11_compile")

    return [
        DefaultInfo(default_output = out),
    ]
285  
# Rule wrapping _pybind11_extension_impl; the output is "<name>.so".
pybind11_extension = rule(
    impl = _pybind11_extension_impl,
    attrs = {
        # C++ sources of the extension module.
        "srcs": attrs.list(attrs.source()),
        # Dependency targets (optional, defaults to none).
        "deps": attrs.list(attrs.dep(), default = []),
        # Targets checked for the NvLibraryInfo provider (optional, defaults to none).
        "nv_deps": attrs.list(attrs.dep(), default = []),
        # Extra compiler flag strings (optional, defaults to none).
        "compiler_flags": attrs.list(attrs.string(), default = []),
    },
)