diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index ed44c0a40555b..c3cd06c8c7010 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -1282,7 +1282,10 @@ def declare_global( def declare_internal_globals(self, module_name: str, emitter: Emitter) -> None: static_name = emitter.static_name("globals", module_name) - self.declare_global("PyObject *", static_name) + if static_name not in self.context.declarations: + self.context.declarations[static_name] = HeaderDeclaration( + f"PyObject *{static_name};", needs_export=True + ) def module_internal_static_name(self, module_name: str, emitter: Emitter) -> str: return emitter.static_name(module_name + "__internal", None, prefix=MODULE_PREFIX) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 30d117e42c71b..3e27cd91b99c9 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -252,6 +252,10 @@ def __init__( self.can_borrow = False + # When set, load_globals_dict uses this module instead of self.module_name. + # Used by generate_attr_defaults_init for cross-module inherited defaults. + self.globals_lookup_module: str | None = None + # High-level control def set_module(self, module_name: str, module_path: str) -> None: @@ -1422,7 +1426,8 @@ def load_global_str(self, name: str, line: int) -> Value: return self.primitive_op(dict_get_item_op, [_globals, reg], line) def load_globals_dict(self) -> Value: - return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name)) + module = self.globals_lookup_module or self.module_name + return self.add(LoadStatic(dict_rprimitive, "globals", module)) def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value: module, _, name = fullname.rpartition(".") diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 03b24cefb7103..bfbe7f34af2e3 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -713,7 +713,7 @@ def add_non_ext_class_attr( def find_attr_initializers( builder: IRBuilder, cdef: ClassDef, skip: Callable[[str, AssignmentStmt], bool] | None = None -) -> tuple[set[str], list[AssignmentStmt]]: +) -> tuple[set[str], list[tuple[AssignmentStmt, str]]]: """Find initializers of attributes in a class body. If provided, the skip arg should be a callable which will return whether @@ -728,7 +728,7 @@ def find_attr_initializers( # Pull out all assignments in classes in the mro so we can initialize them # TODO: Support nested statements - default_assignments = [] + default_assignments: list[tuple[AssignmentStmt, str]] = [] for info in reversed(cdef.info.mro): if info not in builder.mapper.type_to_ir: continue @@ -763,13 +763,13 @@ def find_attr_initializers( continue attrs_with_defaults.add(name) - default_assignments.append(stmt) + default_assignments.append((stmt, info.module_name)) return attrs_with_defaults, default_assignments def generate_attr_defaults_init( - builder: IRBuilder, cdef: ClassDef, default_assignments: list[AssignmentStmt] + builder: IRBuilder, cdef: ClassDef, default_assignments: list[tuple[AssignmentStmt, str]] ) -> None: """Generate an initialization method for default attr values (from class vars).""" if not default_assignments: @@ -780,14 +780,23 @@ def generate_attr_defaults_init( with builder.enter_method(cls, "__mypyc_defaults_setup", bool_rprimitive): self_var = builder.self() - for stmt in default_assignments: + for stmt, origin_module in default_assignments: lvalue = stmt.lvalues[0] assert isinstance(lvalue, NameExpr), lvalue if not stmt.is_final_def and not is_constant(stmt.rvalue): builder.warning("Unsupported default attribute value", stmt.rvalue.line) attr_type = cls.attr_type(lvalue.name) - val = builder.coerce(builder.accept(stmt.rvalue), attr_type, stmt.line) + # When the default comes from a parent in a different module, + # set the globals lookup module so NameExpr references resolve + # against the correct module's globals dict. + builder.globals_lookup_module = ( + origin_module if origin_module != builder.module_name else None + ) + try: + val = builder.coerce(builder.accept(stmt.rvalue), attr_type, stmt.line) + finally: + builder.globals_lookup_module = None init = SetAttr(self_var, lvalue.name, val, stmt.rvalue.line) init.mark_as_initializer() builder.add(init) diff --git a/mypyc/test-data/run-multimodule.test b/mypyc/test-data/run-multimodule.test index 216aed25a5e5b..2c27cdbde9188 100644 --- a/mypyc/test-data/run-multimodule.test +++ b/mypyc/test-data/run-multimodule.test @@ -962,3 +962,27 @@ def translate(b: bytes) -> bytes: [file driver.py] import native assert native.translate(b'ABCD') == b'BBCD' + +[case testCrossModuleAttrDefaults] +from other import Parent + +class Child(Parent): + extra: int = 99 + +def test() -> None: + c = Child() + assert c.config == {"key": "value"} + assert c.extra == 99 + p = Parent() + assert p.config == {"key": "value"} + +[file other.py] +from typing import Dict +MY_DEFAULT: Dict[str, str] = {"key": "value"} + +class Parent: + config: Dict[str, str] = MY_DEFAULT + +[file driver.py] +from native import test +test()