Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def build_graph(
*,
opset_imports: dict[str, int] | None = None,
name: str = "subgraph",
parent: GraphBuilder | None = None,
) -> ir.Graph:
"""Build an :class:`ir.Graph` suitable for use as a graph-valued attribute.

Expand Down Expand Up @@ -165,6 +166,10 @@ def build_graph(
opset_imports: Opset version map for the subgraph (e.g.
``{"": 23}``). Defaults to ``{"": 23}`` when *None*.
name: Name of the resulting :class:`ir.Graph`.
parent: Optional parent :class:`GraphBuilder`. When provided, the
sub-builder's ``_root`` points to the root builder of the parent,
so that :meth:`Parameter._realize` registers initializers in the
root (main) graph rather than the subgraph.

Returns:
An :class:`ir.Graph` whose inputs and outputs are populated and whose
Expand All @@ -188,7 +193,9 @@ def build_graph(
for input_name, ts in resolved_inputs:
subgraph.inputs.append(ir.Value(name=input_name, type=ts.type, shape=ts.shape))

sub_builder = GraphBuilder(subgraph)
sub_builder = GraphBuilder(subgraph, parent=parent)
if parent is not None:
sub_builder._scope_stack = list(parent._scope_stack)
trace_outputs = trace_function(sub_builder.op, *subgraph.inputs)
if not isinstance(trace_outputs, Sequence):
trace_outputs = [trace_outputs]
Expand All @@ -209,8 +216,10 @@ def build_graph(
class GraphBuilder:
"""Imperative builder for constructing ONNX IR graphs with automatic constant promotion, type casting, and shape inference."""

def __init__(self, graph: ir.Graph) -> None:
def __init__(self, graph: ir.Graph, parent: GraphBuilder | None = None) -> None:
self._graph = graph
self._parent = parent
self._root: GraphBuilder = parent._root if parent is not None else self

# Get the opset version for "" (default domain) from the graph
if "" not in graph.opset_imports:
Expand Down Expand Up @@ -238,6 +247,16 @@ def opset(self, domain: str, version: int = 1) -> OpBuilder:
def op(self) -> OpBuilder:
return self._op_builder

@property
def parent(self) -> GraphBuilder | None:
"""The parent builder, or None for a top-level builder."""
return self._parent

@property
def root(self) -> GraphBuilder:
"""The root (top-level) builder in the parent chain."""
return self._root

@property
def graph(self) -> ir.Graph:
return self._graph
Expand Down Expand Up @@ -502,6 +521,7 @@ def subgraph(
outputs,
opset_imports=dict(self._graph.opset_imports),
name=name,
parent=self,
)

def call_op(
Expand Down
82 changes: 82 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,88 @@ def test_build_graph_custom_name(self):
)
self.assertEqual(graph.name, "loop_body")

def test_build_graph_with_parent(self):
"""build_graph with parent sets root on the sub-builder."""
parent_graph = ir.Graph(
name="main",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
parent_builder = builder.GraphBuilder(parent_graph)

def body(op, x):
self.assertIs(op.builder.parent, parent_builder)
self.assertIs(op.builder.root, parent_builder)
return op.Identity(x)

builder.build_graph(
body,
inputs=[FLOAT[3]],
outputs=[FLOAT[3]],
parent=parent_builder,
)

def test_subgraph_sets_parent_and_root(self):
"""GraphBuilder.subgraph() sets parent=self on the sub-builder."""
parent_graph = ir.Graph(
name="main",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
parent_builder = builder.GraphBuilder(parent_graph)

def body(op, x):
self.assertIs(op.builder.parent, parent_builder)
self.assertIs(op.builder.root, parent_builder)
return op.Identity(x)

parent_builder.subgraph(body, inputs=[FLOAT[3]], outputs=[FLOAT[3]])

def test_build_graph_inherits_parent_scope_stack(self):
"""build_graph copies the parent's scope stack so nodes in the subgraph carry scoped names."""
parent_graph = ir.Graph(
name="main",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
parent_builder = builder.GraphBuilder(parent_graph)
parent_builder.push_module("encoder", "Encoder")
parent_builder.push_module("layers.0", "TransformerBlock")

subgraph = builder.build_graph(
lambda op, x: op.Relu(x),
inputs={"x": FLOAT[3, 4]},
outputs={"y": FLOAT[3, 4]},
parent=parent_builder,
)

# The single node created inside the subgraph should carry the
# parent's scope prefix in its name and metadata.
node = subgraph.node(0)
self.assertIn("encoder", node.name)
self.assertIn("layers.0", node.name)
self.assertIn("encoder", node.metadata_props["namespace"])
self.assertIn("TransformerBlock", node.metadata_props["namespace"])

def test_root_graph_builder_is_its_own_root(self):
"""A top-level GraphBuilder has root == self."""
graph = ir.Graph(
name="main",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
gb = builder.GraphBuilder(graph)
self.assertIs(gb.root, gb)
self.assertIsNone(gb.parent)


class PartitionInputsAttributesTest(unittest.TestCase):
"""Tests for GraphBuilder._partition_inputs_attributes."""
Expand Down
71 changes: 71 additions & 0 deletions onnxscript/nn/_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,77 @@ def test_realize_qualifies_name(self):
self.assertEqual(value.name, "layer1.bias")
self.assertIn("layer1.bias", graph.initializers)

def test_realize_in_subgraph_registers_in_root(self):
"""Parameter realized inside a subgraph builder is stored in the root graph."""
from onnxscript._internal.builder import GraphBuilder
from onnxscript.onnx_types import FLOAT

root_graph = ir.Graph(
name="main",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
root_builder = GraphBuilder(root_graph)

p = Parameter([3, 4], name="weight")

def body_fn(op, x):
# Realize param inside a sub-builder context
p._realize(op.builder) # pylint: disable=protected-access
return op.Add(x, x)

_sub_graph = root_builder.subgraph(
body_fn,
inputs=[FLOAT[3, 4]],
outputs=[FLOAT[3, 4]],
)
# Parameter should be in the ROOT graph's initializers, not the subgraph's
self.assertIn("weight", root_graph.initializers)
self.assertIs(root_graph.initializers["weight"], p)
# The subgraph should NOT have the initializer
self.assertNotIn("weight", _sub_graph.initializers)

def test_realize_in_nested_subgraph_registers_in_root(self):
"""Parameter realized in a doubly-nested subgraph goes to the root graph."""
from onnxscript._internal.builder import GraphBuilder, build_graph
from onnxscript.onnx_types import FLOAT

root_graph = ir.Graph(
name="main",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
root_builder = GraphBuilder(root_graph)

p = Parameter([3], name="bias")

def inner_fn(op, x):
p._realize(op.builder) # pylint: disable=protected-access
return op.Identity(x)

def outer_fn(op, x):
# Build a nested subgraph
build_graph(
inner_fn,
inputs=[FLOAT[3]],
outputs=[FLOAT[3]],
parent=op.builder,
)
return op.Identity(x)

root_builder.subgraph(
outer_fn,
inputs=[FLOAT[3]],
outputs=[FLOAT[3]],
)
# Even through two levels of nesting, param ends up in root
self.assertIn("bias", root_graph.initializers)
self.assertIs(root_graph.initializers["bias"], p)


class ModuleBasicTest(unittest.TestCase):
def test_parameter_auto_registration(self):
Expand Down
11 changes: 9 additions & 2 deletions onnxscript/nn/_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def dtype(self) -> ir.DataType | None: # type: ignore[override]
def _realize(self, builder: _builder.GraphBuilder) -> Parameter:
"""Qualify the name and register as a graph initializer.

Uses the builder's *root* graph builder to qualify the name and
register the initializer. When the builder is a sub-builder (e.g.
for a Scan body), this ensures the parameter is stored in the
main graph — making it visible as an implicit input to the
subgraph rather than incorrectly placed inside it.

Uses direct assignment to ``graph.initializers[...]`` to skip the
const_value check. Idempotent: subsequent calls are no-ops.
"""
Expand All @@ -73,8 +79,9 @@ def _realize(self, builder: _builder.GraphBuilder) -> Parameter:
"Ensure the Parameter is attached to a Module attribute or otherwise "
"initialized with a name before realization."
)
self_name = self.name = builder._qualify_initializer_name(self_name) # pylint: disable=protected-access
builder.graph.initializers[self_name] = self
root = builder.root
self_name = self.name = root._qualify_initializer_name(self_name) # pylint: disable=protected-access
root.graph.initializers[self_name] = self
self._realized = True
return self

Expand Down
Loading