Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,10 @@ def declare_global(

def declare_internal_globals(self, module_name: str, emitter: Emitter) -> None:
static_name = emitter.static_name("globals", module_name)
self.declare_global("PyObject *", static_name)
if static_name not in self.context.declarations:
self.context.declarations[static_name] = HeaderDeclaration(
f"PyObject *{static_name};", needs_export=True
)

def module_internal_static_name(self, module_name: str, emitter: Emitter) -> str:
return emitter.static_name(module_name + "__internal", None, prefix=MODULE_PREFIX)
Expand Down
7 changes: 6 additions & 1 deletion mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def __init__(

self.can_borrow = False

# When set, load_globals_dict uses this module instead of self.module_name.
# Used by generate_attr_defaults_init for cross-module inherited defaults.
self.globals_lookup_module: str | None = None

# High-level control

def set_module(self, module_name: str, module_path: str) -> None:
Expand Down Expand Up @@ -1422,7 +1426,8 @@ def load_global_str(self, name: str, line: int) -> Value:
return self.primitive_op(dict_get_item_op, [_globals, reg], line)

def load_globals_dict(self) -> Value:
return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name))
module = self.globals_lookup_module or self.module_name
return self.add(LoadStatic(dict_rprimitive, "globals", module))

def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value:
module, _, name = fullname.rpartition(".")
Expand Down
21 changes: 15 additions & 6 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def add_non_ext_class_attr(

def find_attr_initializers(
builder: IRBuilder, cdef: ClassDef, skip: Callable[[str, AssignmentStmt], bool] | None = None
) -> tuple[set[str], list[AssignmentStmt]]:
) -> tuple[set[str], list[tuple[AssignmentStmt, str]]]:
"""Find initializers of attributes in a class body.

If provided, the skip arg should be a callable which will return whether
Expand All @@ -728,7 +728,7 @@ def find_attr_initializers(

# Pull out all assignments in classes in the mro so we can initialize them
# TODO: Support nested statements
default_assignments = []
default_assignments: list[tuple[AssignmentStmt, str]] = []
for info in reversed(cdef.info.mro):
if info not in builder.mapper.type_to_ir:
continue
Expand Down Expand Up @@ -763,13 +763,13 @@ def find_attr_initializers(
continue

attrs_with_defaults.add(name)
default_assignments.append(stmt)
default_assignments.append((stmt, info.module_name))

return attrs_with_defaults, default_assignments


def generate_attr_defaults_init(
builder: IRBuilder, cdef: ClassDef, default_assignments: list[AssignmentStmt]
builder: IRBuilder, cdef: ClassDef, default_assignments: list[tuple[AssignmentStmt, str]]
) -> None:
"""Generate an initialization method for default attr values (from class vars)."""
if not default_assignments:
Expand All @@ -780,14 +780,23 @@ def generate_attr_defaults_init(

with builder.enter_method(cls, "__mypyc_defaults_setup", bool_rprimitive):
self_var = builder.self()
for stmt in default_assignments:
for stmt, origin_module in default_assignments:
lvalue = stmt.lvalues[0]
assert isinstance(lvalue, NameExpr), lvalue
if not stmt.is_final_def and not is_constant(stmt.rvalue):
builder.warning("Unsupported default attribute value", stmt.rvalue.line)

attr_type = cls.attr_type(lvalue.name)
val = builder.coerce(builder.accept(stmt.rvalue), attr_type, stmt.line)
# When the default comes from a parent in a different module,
# set the globals lookup module so NameExpr references resolve
# against the correct module's globals dict.
builder.globals_lookup_module = (
origin_module if origin_module != builder.module_name else None
)
try:
val = builder.coerce(builder.accept(stmt.rvalue), attr_type, stmt.line)
finally:
builder.globals_lookup_module = None
init = SetAttr(self_var, lvalue.name, val, stmt.rvalue.line)
init.mark_as_initializer()
builder.add(init)
Expand Down
24 changes: 24 additions & 0 deletions mypyc/test-data/run-multimodule.test
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,27 @@ def translate(b: bytes) -> bytes:
[file driver.py]
import native
assert native.translate(b'ABCD') == b'BBCD'

[case testCrossModuleAttrDefaults]
from other import Parent

class Child(Parent):
extra: int = 99

def test() -> None:
c = Child()
assert c.config == {"key": "value"}
assert c.extra == 99
p = Parent()
assert p.config == {"key": "value"}

[file other.py]
from typing import Dict
MY_DEFAULT: Dict[str, str] = {"key": "value"}

class Parent:
config: Dict[str, str] = MY_DEFAULT

[file driver.py]
from native import test
test()
Loading