Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ pip.parse(
requirements_lock = "tests/python/requirements.txt",
)
use_repo(pip, "rules_mojo_test_deps")

link_hack = use_repo_rule("//mojo/private:link_hack.bzl", "link_hack")

link_hack(
name = "build_bazel_rules_android", # See link_hack.bzl for details
)
5 changes: 5 additions & 0 deletions mojo/mojo_shared_library.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""A rule for creating shared libraries written in Mojo."""

load("//mojo/private:mojo_binary_test.bzl", _mojo_shared_library = "mojo_shared_library")

mojo_shared_library = _mojo_shared_library
15 changes: 15 additions & 0 deletions mojo/private/link_hack.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""This rule hacks around a private API limitation in bazel by re-using the name of a library that is allowed to access the private API.

https://github.com/bazelbuild/bazel/pull/23838
"""

def _link_hack_impl(rctx):
rctx.file("BUILD.bazel", "")
rctx.file("link_hack.bzl", """\
def link_hack(**kwargs):
return cc_common.link(**kwargs)
""")

link_hack = repository_rule(
implementation = _link_hack_impl,
)
73 changes: 60 additions & 13 deletions mojo/private/mojo_binary_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
load("@build_bazel_rules_android//:link_hack.bzl", "link_hack") # See link_hack.bzl for details
load("@rules_python//python:py_info.bzl", "PyInfo")
load("//mojo:providers.bzl", "MojoInfo")
load(":utils.bzl", "MOJO_EXTENSIONS", "collect_mojoinfo")
Expand Down Expand Up @@ -57,7 +58,7 @@ def _find_main(name, srcs, main):

fail("Multiple Mojo files provided, but no main file specified. Please set 'main = \"foo.mojo\"' to disambiguate.")

def _mojo_binary_test_implementation(ctx):
def _mojo_binary_test_implementation(ctx, *, shared_library = False):
cc_toolchain = find_cpp_toolchain(ctx)
mojo_toolchain = ctx.toolchains["//:toolchain_type"].mojo_toolchain_info
py_toolchain = ctx.toolchains["@bazel_tools//tools/python:toolchain_type"]
Expand Down Expand Up @@ -132,12 +133,20 @@ def _mojo_binary_test_implementation(ctx):
]),
)]),
)
linking_outputs = cc_common.link(

link_kwargs = {}
if shared_library:
link_kwargs["output_type"] = "dynamic_library"
if ctx.attr.shared_lib_name:
link_kwargs["main_output"] = ctx.actions.declare_file(ctx.attr.shared_lib_name) # Only set if name is not using the default logic

linking_outputs = link_hack(
actions = ctx.actions,
feature_configuration = feature_configuration,
cc_toolchain = cc_toolchain,
linking_contexts = [object_linking_context] + [dep[CcInfo].linking_context for dep in (ctx.attr.deps + mojo_toolchain.implicit_deps) if CcInfo in dep],
name = ctx.label.name,
**link_kwargs
)

data = ctx.attr.data
Expand All @@ -150,6 +159,7 @@ def _mojo_binary_test_implementation(ctx):

# Collect transitive shared libraries that must exist at runtime
python_imports = []
transitive_libraries = []
for target in ctx.attr.deps + mojo_toolchain.implicit_deps:
transitive_runfiles.append(target[DefaultInfo].default_runfiles)

Expand All @@ -164,6 +174,7 @@ def _mojo_binary_test_implementation(ctx):
for linker_input in target[CcInfo].linking_context.linker_inputs.to_list():
for library in linker_input.libraries:
if library.dynamic_library and not library.pic_static_library and not library.static_library:
transitive_libraries.append(depset([library]))
transitive_runfiles.append(ctx.runfiles(transitive_files = depset([library.dynamic_library])))

python_path = ""
Expand Down Expand Up @@ -196,28 +207,64 @@ def _mojo_binary_test_implementation(ctx):
{},
)

return [
DefaultInfo(
executable = linking_outputs.executable,
runfiles = runfiles.merge_all(transitive_runfiles),
),
RunEnvironmentInfo(
environment = runtime_env,
),
]
if shared_library:
return [
DefaultInfo(
executable = linking_outputs.library_to_link.resolved_symlink_dynamic_library,
runfiles = runfiles.merge_all(transitive_runfiles),
),
PyInfo(
imports = depset(["_main/" + paths.dirname(linking_outputs.library_to_link.dynamic_library.short_path)]),
transitive_sources = depset([linking_outputs.library_to_link.dynamic_library]),
),
CcInfo(
linking_context = cc_common.create_linking_context(
linker_inputs = depset([
cc_common.create_linker_input(
owner = ctx.label,
libraries = depset(
[linking_outputs.library_to_link],
transitive = transitive_libraries,
),
),
]),
),
),
]
else:
return [
DefaultInfo(
executable = linking_outputs.executable,
runfiles = runfiles.merge_all(transitive_runfiles),
),
RunEnvironmentInfo(
environment = runtime_env,
),
]

mojo_binary = rule(
implementation = _mojo_binary_test_implementation,
implementation = lambda ctx: _mojo_binary_test_implementation(ctx),
attrs = _ATTRS,
toolchains = _TOOLCHAINS,
fragments = ["cpp"],
executable = True,
)

mojo_test = rule(
implementation = _mojo_binary_test_implementation,
implementation = lambda ctx: _mojo_binary_test_implementation(ctx),
attrs = _ATTRS,
toolchains = _TOOLCHAINS,
fragments = ["cpp"],
test = True,
)

mojo_shared_library = rule(
implementation = lambda ctx: _mojo_binary_test_implementation(ctx, shared_library = True),
attrs = _ATTRS | {
"shared_lib_name": attr.string(
doc = "The name of the shared library to be created.",
),
},
toolchains = _TOOLCHAINS,
fragments = ["cpp"],
)
18 changes: 18 additions & 0 deletions tests/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@bazel_skylib//rules:build_test.bzl", "build_test")
load("//mojo:mojo_binary.bzl", "mojo_binary")
load("//mojo:mojo_shared_library.bzl", "mojo_shared_library")
load("//mojo:mojo_test.bzl", "mojo_test")

mojo_binary(
Expand Down Expand Up @@ -28,3 +29,20 @@ mojo_test(
"//tests/package",
],
)

mojo_shared_library(
name = "shared_library",
srcs = [
"shared_library.mojo",
],
)

mojo_test(
name = "shared_library_test",
srcs = [
"shared_library_test.mojo",
],
deps = [
":shared_library",
],
)
14 changes: 14 additions & 0 deletions tests/python/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
load("@rules_mojo_test_deps//:requirements.bzl", "requirement")
load("@rules_python//python:defs.bzl", "py_test")
load("//mojo:mojo_shared_library.bzl", "mojo_shared_library")
load("//mojo:mojo_test.bzl", "mojo_test")

mojo_test(
Expand All @@ -13,3 +15,15 @@ mojo_test(
requirement("numpy"),
],
)

mojo_shared_library(
name = "python_shared_library",
srcs = ["python_shared_library.mojo"],
shared_lib_name = "python_shared_library.so",
)

py_test(
name = "python_shared_library_test",
srcs = ["python_shared_library_test.py"],
deps = [":python_shared_library"],
)
30 changes: 30 additions & 0 deletions tests/python/python_shared_library.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from os import abort

from python import Python, PythonObject
from python.bindings import PythonModuleBuilder
from python._cpython import PyObjectPtr


@export
fn PyInit_python_shared_library() -> PythonObject:
"""Create a Python module with a function binding for `mojo_count_args`."""

try:
var b = PythonModuleBuilder("python_shared_library")
b.def_py_c_function(
mojo_count_args,
"mojo_count_args",
docstring="Count the provided arguments",
)
return b.finalize()
except e:
return abort[PythonObject](
String("failed to create Python module: ", e)
)


@export
fn mojo_count_args(py_self: PyObjectPtr, args: PyObjectPtr) -> PyObjectPtr:
var cpython = Python().cpython()

return PythonObject(cpython.PyObject_Length(args)).py_object
6 changes: 6 additions & 0 deletions tests/python/python_shared_library_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import python_shared_library

if __name__ == "__main__":
result = python_shared_library.mojo_count_args(1, 2)
assert result == 2
print("Result from Mojo 🔥:", result)
3 changes: 3 additions & 0 deletions tests/shared_library.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@export
fn foo() -> Int:
return 42
8 changes: 8 additions & 0 deletions tests/shared_library_test.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from sys.ffi import c_int, external_call
from testing import assert_equal

def main():
print("Calling external function...")
result = external_call["foo", c_int]()
print("Result from external function:", result)
assert_equal(result, 42)