From 342d46a429ec245f1165f6ab351c249c6a7fd759 Mon Sep 17 00:00:00 2001 From: Keith Smiley Date: Mon, 23 Jun 2025 21:57:05 +0000 Subject: [PATCH] Add mojo_shared_library rule This allows you to produce a shared library from mojo code. This can be used to create native python extensions as seen in the tests. This requires 1 major hack documented in link_hack.bzl to get shared libraries with the right filename for python. --- MODULE.bazel | 6 ++ mojo/mojo_shared_library.bzl | 5 ++ mojo/private/link_hack.bzl | 15 +++++ mojo/private/mojo_binary_test.bzl | 73 ++++++++++++++++++---- tests/BUILD.bazel | 18 ++++++ tests/python/BUILD.bazel | 14 +++++ tests/python/python_shared_library.mojo | 30 +++++++++ tests/python/python_shared_library_test.py | 6 ++ tests/shared_library.mojo | 3 + tests/shared_library_test.mojo | 8 +++ 10 files changed, 165 insertions(+), 13 deletions(-) create mode 100644 mojo/mojo_shared_library.bzl create mode 100644 mojo/private/link_hack.bzl create mode 100644 tests/python/python_shared_library.mojo create mode 100644 tests/python/python_shared_library_test.py create mode 100644 tests/shared_library.mojo create mode 100644 tests/shared_library_test.mojo diff --git a/MODULE.bazel b/MODULE.bazel index 7b24746..b45f69d 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -32,3 +32,9 @@ pip.parse( requirements_lock = "tests/python/requirements.txt", ) use_repo(pip, "rules_mojo_test_deps") + +link_hack = use_repo_rule("//mojo/private:link_hack.bzl", "link_hack") + +link_hack( + name = "build_bazel_rules_android", # See link_hack.bzl for details +) diff --git a/mojo/mojo_shared_library.bzl b/mojo/mojo_shared_library.bzl new file mode 100644 index 0000000..33ee38a --- /dev/null +++ b/mojo/mojo_shared_library.bzl @@ -0,0 +1,5 @@ +"""A rule for creating shared libraries written in Mojo.""" + +load("//mojo/private:mojo_binary_test.bzl", _mojo_shared_library = "mojo_shared_library") + +mojo_shared_library = _mojo_shared_library diff --git a/mojo/private/link_hack.bzl b/mojo/private/link_hack.bzl new file mode 100644 index 0000000..cd72beb --- /dev/null +++ b/mojo/private/link_hack.bzl @@ -0,0 +1,15 @@ +"""This rule hacks around a private API limitation in bazel by re-using the name of a library that is allowed to access the private API. + +https://github.com/bazelbuild/bazel/pull/23838 +""" + +def _link_hack_impl(rctx): + rctx.file("BUILD.bazel", "") + rctx.file("link_hack.bzl", """\ +def link_hack(**kwargs): + return cc_common.link(**kwargs) +""") + +link_hack = repository_rule( + implementation = _link_hack_impl, +) diff --git a/mojo/private/mojo_binary_test.bzl b/mojo/private/mojo_binary_test.bzl index 374e161..c45b13f 100644 --- a/mojo/private/mojo_binary_test.bzl +++ b/mojo/private/mojo_binary_test.bzl @@ -3,6 +3,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain") +load("@build_bazel_rules_android//:link_hack.bzl", "link_hack") # See link_hack.bzl for details load("@rules_python//python:py_info.bzl", "PyInfo") load("//mojo:providers.bzl", "MojoInfo") load(":utils.bzl", "MOJO_EXTENSIONS", "collect_mojoinfo") @@ -57,7 +58,7 @@ def _find_main(name, srcs, main): fail("Multiple Mojo files provided, but no main file specified. Please set 'main = \"foo.mojo\"' to disambiguate.") -def _mojo_binary_test_implementation(ctx): +def _mojo_binary_test_implementation(ctx, *, shared_library = False): cc_toolchain = find_cpp_toolchain(ctx) mojo_toolchain = ctx.toolchains["//:toolchain_type"].mojo_toolchain_info py_toolchain = ctx.toolchains["@bazel_tools//tools/python:toolchain_type"] @@ -132,12 +133,20 @@ def _mojo_binary_test_implementation(ctx): ]), )]), ) - linking_outputs = cc_common.link( + + link_kwargs = {} + if shared_library: + link_kwargs["output_type"] = "dynamic_library" + if ctx.attr.shared_lib_name: + link_kwargs["main_output"] = ctx.actions.declare_file(ctx.attr.shared_lib_name) # Only set if name is not using the default logic + + linking_outputs = link_hack( actions = ctx.actions, feature_configuration = feature_configuration, cc_toolchain = cc_toolchain, linking_contexts = [object_linking_context] + [dep[CcInfo].linking_context for dep in (ctx.attr.deps + mojo_toolchain.implicit_deps) if CcInfo in dep], name = ctx.label.name, + **link_kwargs ) data = ctx.attr.data @@ -150,6 +159,7 @@ def _mojo_binary_test_implementation(ctx): # Collect transitive shared libraries that must exist at runtime python_imports = [] + transitive_libraries = [] for target in ctx.attr.deps + mojo_toolchain.implicit_deps: transitive_runfiles.append(target[DefaultInfo].default_runfiles) @@ -164,6 +174,7 @@ def _mojo_binary_test_implementation(ctx): for linker_input in target[CcInfo].linking_context.linker_inputs.to_list(): for library in linker_input.libraries: if library.dynamic_library and not library.pic_static_library and not library.static_library: + transitive_libraries.append(depset([library])) transitive_runfiles.append(ctx.runfiles(transitive_files = depset([library.dynamic_library]))) python_path = "" @@ -196,18 +207,43 @@ def _mojo_binary_test_implementation(ctx): {}, ) - return [ - DefaultInfo( - executable = linking_outputs.executable, - runfiles = runfiles.merge_all(transitive_runfiles), - ), - RunEnvironmentInfo( - environment = runtime_env, - ), - ] + if shared_library: + return [ + DefaultInfo( + executable = linking_outputs.library_to_link.resolved_symlink_dynamic_library, + runfiles = runfiles.merge_all(transitive_runfiles), + ), + PyInfo( + imports = depset(["_main/" + paths.dirname(linking_outputs.library_to_link.dynamic_library.short_path)]), + transitive_sources = depset([linking_outputs.library_to_link.dynamic_library]), + ), + CcInfo( + linking_context = cc_common.create_linking_context( + linker_inputs = depset([ + cc_common.create_linker_input( + owner = ctx.label, + libraries = depset( + [linking_outputs.library_to_link], + transitive = transitive_libraries, + ), + ), + ]), + ), + ), + ] + else: + return [ + DefaultInfo( + executable = linking_outputs.executable, + runfiles = runfiles.merge_all(transitive_runfiles), + ), + RunEnvironmentInfo( + environment = runtime_env, + ), + ] mojo_binary = rule( - implementation = _mojo_binary_test_implementation, + implementation = lambda ctx: _mojo_binary_test_implementation(ctx), attrs = _ATTRS, toolchains = _TOOLCHAINS, fragments = ["cpp"], @@ -215,9 +251,20 @@ mojo_binary = rule( ) mojo_test = rule( - implementation = _mojo_binary_test_implementation, + implementation = lambda ctx: _mojo_binary_test_implementation(ctx), attrs = _ATTRS, toolchains = _TOOLCHAINS, fragments = ["cpp"], test = True, ) + +mojo_shared_library = rule( + implementation = lambda ctx: _mojo_binary_test_implementation(ctx, shared_library = True), + attrs = _ATTRS | { + "shared_lib_name": attr.string( + doc = "The name of the shared library to be created.", + ), + }, + toolchains = _TOOLCHAINS, + fragments = ["cpp"], +) diff --git a/tests/BUILD.bazel b/tests/BUILD.bazel index 1772b2d..eb7f81d 100644 --- a/tests/BUILD.bazel +++ b/tests/BUILD.bazel @@ -1,5 +1,6 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("//mojo:mojo_binary.bzl", "mojo_binary") +load("//mojo:mojo_shared_library.bzl", "mojo_shared_library") load("//mojo:mojo_test.bzl", "mojo_test") mojo_binary( @@ -28,3 +29,20 @@ mojo_test( "//tests/package", ], ) + +mojo_shared_library( + name = "shared_library", + srcs = [ + "shared_library.mojo", + ], +) + +mojo_test( + name = "shared_library_test", + srcs = [ + "shared_library_test.mojo", + ], + deps = [ + ":shared_library", + ], +) diff --git a/tests/python/BUILD.bazel b/tests/python/BUILD.bazel index 7491c8f..4ce810d 100644 --- a/tests/python/BUILD.bazel +++ b/tests/python/BUILD.bazel @@ -1,4 +1,6 @@ load("@rules_mojo_test_deps//:requirements.bzl", "requirement") +load("@rules_python//python:defs.bzl", "py_test") +load("//mojo:mojo_shared_library.bzl", "mojo_shared_library") load("//mojo:mojo_test.bzl", "mojo_test") mojo_test( @@ -13,3 +15,15 @@ mojo_test( requirement("numpy"), ], ) + +mojo_shared_library( + name = "python_shared_library", + srcs = ["python_shared_library.mojo"], + shared_lib_name = "python_shared_library.so", +) + +py_test( + name = "python_shared_library_test", + srcs = ["python_shared_library_test.py"], + deps = [":python_shared_library"], +) diff --git a/tests/python/python_shared_library.mojo b/tests/python/python_shared_library.mojo new file mode 100644 index 0000000..c42b72b --- /dev/null +++ b/tests/python/python_shared_library.mojo @@ -0,0 +1,30 @@ +from os import abort + +from python import Python, PythonObject +from python.bindings import PythonModuleBuilder +from python._cpython import PyObjectPtr + + +@export +fn PyInit_python_shared_library() -> PythonObject: + """Create a Python module with a function binding for `mojo_count_args`.""" + + try: + var b = PythonModuleBuilder("python_shared_library") + b.def_py_c_function( + mojo_count_args, + "mojo_count_args", + docstring="Count the provided arguments", + ) + return b.finalize() + except e: + return abort[PythonObject]( + String("failed to create Python module: ", e) + ) + + +@export +fn mojo_count_args(py_self: PyObjectPtr, args: PyObjectPtr) -> PyObjectPtr: + var cpython = Python().cpython() + + return PythonObject(cpython.PyObject_Length(args)).py_object diff --git a/tests/python/python_shared_library_test.py b/tests/python/python_shared_library_test.py new file mode 100644 index 0000000..0b34216 --- /dev/null +++ b/tests/python/python_shared_library_test.py @@ -0,0 +1,6 @@ +import python_shared_library + +if __name__ == "__main__": + result = python_shared_library.mojo_count_args(1, 2) + assert result == 2 + print("Result from Mojo 🔥:", result) diff --git a/tests/shared_library.mojo b/tests/shared_library.mojo new file mode 100644 index 0000000..a17d353 --- /dev/null +++ b/tests/shared_library.mojo @@ -0,0 +1,3 @@ +@export +fn foo() -> Int: + return 42 diff --git a/tests/shared_library_test.mojo b/tests/shared_library_test.mojo new file mode 100644 index 0000000..88fa5b2 --- /dev/null +++ b/tests/shared_library_test.mojo @@ -0,0 +1,8 @@ +from sys.ffi import c_int, external_call +from testing import assert_equal + +def main(): + print("Calling external function...") + result = external_call["foo", c_int]() + print("Result from external function:", result) + assert_equal(result, 42)