From 5324aadfe4018d9e16cf1eee3dc3eff9aeda2efa Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Sat, 28 Mar 2026 13:07:44 +0800 Subject: [PATCH 1/4] [Relax][ONNX] Add Optional and MatMulInteger16 frontend support --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 111 ++++++++- tests/python/relax/test_frontend_onnx.py | 228 ++++++++++++++++++ 2 files changed, 330 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3dc575ae778c..e91a2c763d67 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -317,6 +317,37 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.matmul(inputs[0], inputs[1]) +class MatMulInteger16(OnnxOpConverter): + """Converts an ONNX MatMulInteger16 node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if len(inputs) != 2: + raise ValueError(f"MatMulInteger16 expects two inputs, but got {len(inputs)}") + a, b = inputs + valid_types = ["int16", "uint16"] + if a.struct_info.dtype not in valid_types: + raise ValueError( + "MatMulInteger16 expects input A to have int16 or uint16 dtype, " + f"but got {a.struct_info.dtype}" + ) + if b.struct_info.dtype not in valid_types: + raise ValueError( + "MatMulInteger16 expects input B to have int16 or uint16 dtype, " + f"but got {b.struct_info.dtype}" + ) + + out_dtype = ( + "uint32" + if a.struct_info.dtype == "uint16" and b.struct_info.dtype == "uint16" + else "int32" + ) + return relax.op.matmul( + relax.op.astype(a, out_dtype), + relax.op.astype(b, out_dtype), + ) + + def _to_numpy(x): if isinstance(x, relax.PrimValue): x = x.value @@ -327,6 +358,19 @@ def _to_numpy(x): return x.data.numpy() +class _EmptyOptional: + """Sentinel object used to represent an empty ONNX Optional value.""" + + def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto): + self.type_proto = type_proto + + +def _is_empty_optional(value: Any) -> bool: + """Returns whether the given value represents an empty ONNX Optional.""" + + return isinstance(value, _EmptyOptional) + + class BinaryBase(OnnxOpConverter): """Converts an onnx BinaryBase node into an equivalent Relax expression.""" @@ -3565,6 +3609,50 @@ def _impl_v1(cls, bb, inputs, attr, params): ) +class Optional_(OnnxOpConverter): + """Operator converter for Optional.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) > 1: + raise ValueError(f"Optional accepts at most one input, but got {len(inputs)}") + if len(inputs) == 0 or inputs[0] is None: + if "type" not in attr: + raise ValueError("Optional without an input must specify the type attribute.") + return _EmptyOptional(attr["type"]) + return inputs[0] + + _impl_v18 = _impl_v15 + + +class OptionalHasElement(OnnxOpConverter): + """Operator converter for OptionalHasElement.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) > 1: + raise ValueError(f"OptionalHasElement accepts at most one input, but got {len(inputs)}") + if len(inputs) == 0 or inputs[0] is None or _is_empty_optional(inputs[0]): + return relax.const(False, dtype="bool") + return relax.const(True, dtype="bool") + + _impl_v18 = _impl_v15 + + +class OptionalGetElement(OnnxOpConverter): + """Operator converter for OptionalGetElement.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) != 1: + raise ValueError(f"OptionalGetElement expects one input, but got {len(inputs)}") + if inputs[0] is None or _is_empty_optional(inputs[0]): + raise ValueError("OptionalGetElement cannot access an empty optional.") + return inputs[0] + + _impl_v18 = _impl_v15 + + class SequenceConstruct(OnnxOpConverter): """Operator converter for sequence construction op.""" @@ -3898,9 +3986,9 @@ def _impl_v1(cls, bb, inputs, attr, params): def _get_convert_map(): return { # defs/experimental - # "Optional": Optional_, - # "OptionalHasElement": OptionalHasElement, - # "OptionalGetElement": OptionalGetElement, + "Optional": Optional_, + "OptionalHasElement": OptionalHasElement, + "OptionalGetElement": OptionalGetElement, # Binary operators "Add": Add, "Sub": Sub, @@ -3971,7 +4059,7 @@ def _get_convert_map(): "Gemm": Gemm, "MatMul": MatMul, # "MatMulInteger": MatMulInteger, - # "MatMulInteger16": MatMulInteger16, + "MatMulInteger16": MatMulInteger16, "Reshape": Reshape, "Sigmoid": Sigmoid, "Softmax": Softmax, @@ -4281,6 +4369,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): "Squeeze", ] return_tuple_ops = [ + "Optional", + "OptionalGetElement", "SequenceConstruct", "SequenceEmpty", "SequenceErase", @@ -4299,13 +4389,16 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): try: op = self._convert_operator(op_name, inputs, attr, self.opset) # Create struct information for the new operator. - op = self.bb.normalize(op) + if isinstance(op, relax.Expr): + op = self.bb.normalize(op) except TVMError as err: print(f"Error converting operator {op_name}, with inputs: {inputs}") raise err if op_name in return_tuple_ops: outputs_num = 1 + elif _is_empty_optional(op): + outputs_num = 1 elif not isinstance(op, relax.Tuple): if isinstance(op.struct_info, relax.TupleStructInfo): # This is a var bound to a tuple. We need to unpack it and create @@ -4351,11 +4444,11 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> dict[str, if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) - for f in ["t"]: - if a.HasField(f): + for f in ["t", "tp"]: + if hasattr(a, f) and a.HasField(f): attrs[a.name] = getattr(a, f) - for f in ["tensors"]: - if list(getattr(a, f)): + for f in ["tensors", "type_protos"]: + if hasattr(a, f) and list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) for f in ["graphs"]: diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index ecbc6c9e8a5e..8c4224b68c8e 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -197,6 +197,41 @@ def _check_output(tvm_out, ort_out): _check_output(tvm_out, ort_out) +def run_in_tvm( + model: ModelProto, + inputs: dict[str, np.ndarray] | None = None, + ir_version: int = 8, + opset: int = 14, +): + if ir_version is not None: + model.ir_version = ir_version + if opset is not None: + for opset_import in model.opset_import: + if opset_import.domain in ["", "ai.onnx"]: + opset_import.version = opset + break + + inputs = generate_random_inputs(model, inputs) + tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + return vm.get_outputs("main") + + @pytest.mark.parametrize( "input_names, expected_names", [ @@ -374,6 +409,80 @@ def test_matmul(dynamic): check_correctness(model, inputs) +@pytest.mark.parametrize( + ("a_dtype", "b_dtype", "a_shape", "b_shape"), + [ + (np.int16, np.int16, [2, 3], [3, 4]), + (np.uint16, np.uint16, [2, 3], [3, 4]), + (np.int16, np.uint16, [2, 1, 3, 5], [1, 2, 5, 4]), + ], +) +def test_matmulinteger16(a_dtype, b_dtype, a_shape, b_shape): + a = np.arange(np.prod(a_shape), dtype=np.int64).reshape(a_shape) + b = np.arange(np.prod(b_shape), dtype=np.int64).reshape(b_shape) + if np.issubdtype(a_dtype, np.signedinteger): + a -= a.size // 2 + if np.issubdtype(b_dtype, np.signedinteger): + b -= b.size // 2 + a = a.astype(a_dtype) + b = b.astype(b_dtype) + + out_dtype = np.uint32 if a_dtype == np.uint16 and b_dtype == np.uint16 else np.int32 + expected = np.matmul(a.astype(out_dtype), b.astype(out_dtype)) + + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_test", + inputs=[ + helper.make_tensor_value_info("a", helper.np_dtype_to_tensor_dtype(a.dtype), a_shape), + helper.make_tensor_value_info("b", helper.np_dtype_to_tensor_dtype(b.dtype), b_shape), + ], + outputs=[ + helper.make_tensor_value_info( + "y", helper.np_dtype_to_tensor_dtype(np.dtype(out_dtype)), expected.shape + ) + ], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + tvm_output = run_in_tvm(model, inputs={"a": a, "b": b}, ir_version=11, opset=18) + assert isinstance(tvm_output, tvm.runtime.Tensor) + assert tvm_output.numpy().dtype == out_dtype + tvm.testing.assert_allclose(tvm_output.numpy(), expected) + + +def test_matmulinteger16_ir(): + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_ir_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.UINT16, [2, 3]), + helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.UINT32, [2, 4])], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_ir_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + tvm_model_str = str(tvm_model) + assert "MatMulInteger16" not in tvm_model_str + assert tvm_model_str.count('R.astype(') >= 2 + assert "R.matmul" in tvm_model_str + assert 'dtype="uint32"' in tvm_model_str + + def test_concat(): verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) @@ -3138,6 +3247,21 @@ def make_constant_node(name: str, data_type: int, dims: list[int], vals: list[in ) +def make_optional_tensor_value_info(name: str, elem_type: int, shape: list[int]): + return helper.make_value_info( + name, helper.make_optional_type_proto(helper.make_tensor_type_proto(elem_type, shape)) + ) + + +def make_optional_sequence_value_info(name: str, elem_type: int, shape: list[int]): + return helper.make_value_info( + name, + helper.make_optional_type_proto( + helper.make_sequence_type_proto(helper.make_tensor_type_proto(elem_type, shape)) + ), + ) + + def test_sequence_construct(): node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=2) graph = helper.make_graph( @@ -3249,6 +3373,110 @@ def test_sequence_at(): check_correctness(model) +def test_optional_get_element_tensor(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_tensor", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_tensor") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_tensor(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_tensor", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_tensor") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_empty(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_empty", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_empty") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_empty_ir(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_empty_ir", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_empty_ir") + model.ir_version = 11 + model.opset_import[0].version = 18 + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + + assert 'R.const(False, "bool")' in str(tvm_model) + + +def test_optional_get_element_sequence(): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [1]) + optional_node = helper.make_node("Optional", ["sequence"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["unwrapped"]) + sequence_at_node = helper.make_node("SequenceAt", ["unwrapped", "index"], ["output"]) + graph = helper.make_graph( + [index, seq_node, optional_node, get_element_node, sequence_at_node], + "test_optional_get_element_sequence", + inputs=graph_inputs, + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 32])], + value_info=[make_optional_sequence_value_info("optional", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_sequence") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_get_element_empty_raises(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_empty_raises", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_empty_raises") + model.opset_import[0].version = 18 + with pytest.raises(ValueError, match="empty optional"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_symbolic_shape_deduction(): index_node = helper.make_node( "Constant", From 8a8197a0d07c55e038aedcf65ba4e515442cdff3 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Sat, 28 Mar 2026 19:38:02 +0800 Subject: [PATCH 2/4] [Relax][ONNX] Tighten Optional and MatMulInteger16 checks --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 8 +-- tests/python/relax/test_frontend_onnx.py | 60 +++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index e91a2c763d67..503d6c339e9e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -359,7 +359,7 @@ def _to_numpy(x): class _EmptyOptional: - """Sentinel object used to represent an empty ONNX Optional value.""" + """Sentinel object that preserves an empty ONNX Optional during import.""" def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto): self.type_proto = type_proto @@ -3610,7 +3610,7 @@ def _impl_v1(cls, bb, inputs, attr, params): class Optional_(OnnxOpConverter): - """Operator converter for Optional.""" + """Converts an ONNX Optional node into an erased or empty Optional representation.""" @classmethod def _impl_v15(cls, bb, inputs, attr, params): @@ -3626,7 +3626,7 @@ def _impl_v15(cls, bb, inputs, attr, params): class OptionalHasElement(OnnxOpConverter): - """Operator converter for OptionalHasElement.""" + """Converts an ONNX OptionalHasElement node into a boolean constant.""" @classmethod def _impl_v15(cls, bb, inputs, attr, params): @@ -3640,7 +3640,7 @@ def _impl_v15(cls, bb, inputs, attr, params): class OptionalGetElement(OnnxOpConverter): - """Operator converter for OptionalGetElement.""" + """Converts an ONNX OptionalGetElement node by unwrapping a non-empty Optional.""" @classmethod def _impl_v15(cls, bb, inputs, attr, params): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8c4224b68c8e..4b5bf056efe3 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -483,6 +483,28 @@ def test_matmulinteger16_ir(): assert 'dtype="uint32"' in tvm_model_str +def test_matmulinteger16_invalid_dtype_raises(): + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_invalid_dtype_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.INT8, [2, 3]), + helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, [2, 4])], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_invalid_dtype_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + with pytest.raises(ValueError, match="input A"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_concat(): verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) @@ -3441,6 +3463,27 @@ def test_optional_has_element_empty_ir(): assert 'R.const(False, "bool")' in str(tvm_model) +def test_optional_get_element_tensor_ir(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_tensor_ir", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_tensor_ir") + model.ir_version = 11 + model.opset_import[0].version = 18 + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + + tvm_model_str = str(tvm_model) + assert "Optional" not in tvm_model_str + assert "OptionalGetElement" not in tvm_model_str + + def test_optional_get_element_sequence(): seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) index = make_constant_node("index", TensorProto.INT64, (), [1]) @@ -3458,6 +3501,23 @@ def test_optional_get_element_sequence(): check_correctness(model, opset=18, ir_version=11) +def test_optional_without_input_requires_type_attr(): + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3]) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"]) + graph = helper.make_graph( + [optional_node], + "test_optional_without_input_requires_type_attr", + inputs=[], + outputs=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_without_input_requires_type_attr") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="type attribute"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_optional_get_element_empty_raises(): x_shape = [2, 3] tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) From f6de1f133be324bd0bd99ac64c951dbeda14b19a Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Sat, 28 Mar 2026 23:16:56 +0800 Subject: [PATCH 3/4] [Relax][ONNX] Address Optional review feedback --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 6 +- tests/python/relax/test_frontend_onnx.py | 56 ++++++++++++++++--- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index bcb8419ed1ff..24ea74ec82cb 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3728,9 +3728,9 @@ class OptionalHasElement(OnnxOpConverter): @classmethod def _impl_v15(cls, bb, inputs, attr, params): - if len(inputs) > 1: - raise ValueError(f"OptionalHasElement accepts at most one input, but got {len(inputs)}") - if len(inputs) == 0 or inputs[0] is None or _is_empty_optional(inputs[0]): + if len(inputs) != 1: + raise ValueError(f"OptionalHasElement expects one input, but got {len(inputs)}") + if inputs[0] is None or _is_empty_optional(inputs[0]): return relax.const(False, dtype="bool") return relax.const(True, dtype="bool") diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 3bb18f56bbbf..ea08fd0a5805 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -232,6 +232,30 @@ def run_in_tvm( return vm.get_outputs("main") +def collect_relax_call_ops(func: relax.Function) -> list[str]: + op_names = [] + + def fvisit(expr): + if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op): + op_names.append(expr.op.name) + + relax.analysis.post_order_visit(func.body, fvisit) + return op_names + + +def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]: + values = [] + + def fvisit(expr): + if isinstance(expr, relax.Constant): + value = expr.data.numpy() + if value.shape == (): + values.append(value.item()) + + relax.analysis.post_order_visit(func.body, fvisit) + return values + + @pytest.mark.parametrize( "input_names, expected_names", [ @@ -476,11 +500,10 @@ def test_matmulinteger16_ir(): model.ir_version = 11 tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) - tvm_model_str = str(tvm_model) - assert "MatMulInteger16" not in tvm_model_str - assert tvm_model_str.count('R.astype(') >= 2 - assert "R.matmul" in tvm_model_str - assert 'dtype="uint32"' in tvm_model_str + call_ops = collect_relax_call_ops(tvm_model["main"]) + assert call_ops.count("relax.astype") == 2 + assert "relax.matmul" in call_ops + assert tvm_model["main"].ret_struct_info.dtype == "uint32" def test_matmulinteger16_invalid_dtype_raises(): @@ -3491,7 +3514,8 @@ def test_optional_has_element_empty_ir(): model.opset_import[0].version = 18 tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) - assert 'R.const(False, "bool")' in str(tvm_model) + assert collect_relax_call_ops(tvm_model["main"]) == [] + assert False in collect_scalar_constants(tvm_model["main"]) def test_optional_get_element_tensor_ir(): @@ -3510,9 +3534,8 @@ def test_optional_get_element_tensor_ir(): model.opset_import[0].version = 18 tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) - tvm_model_str = str(tvm_model) - assert "Optional" not in tvm_model_str - assert "OptionalGetElement" not in tvm_model_str + assert collect_relax_call_ops(tvm_model["main"]) == [] + assert tvm_model["main"].ret_struct_info.dtype == "float32" def test_optional_get_element_sequence(): @@ -3549,6 +3572,21 @@ def test_optional_without_input_requires_type_attr(): from_onnx(model, opset=18, keep_params_in_input=True) +def test_optional_has_element_requires_one_input(): + has_element_node = helper.make_node("OptionalHasElement", [], ["output"]) + graph = helper.make_graph( + [has_element_node], + "test_optional_has_element_requires_one_input", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_requires_one_input") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="expects one input"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_optional_get_element_empty_raises(): x_shape = [2, 3] tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) From e18e0183dcdab78beab7ffa9ec5e270cc431eb90 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:48:30 +0800 Subject: [PATCH 4/4] [Relax][ONNX] Guard empty Optional graph outputs and remove dead Optional path --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 15 ++++++++++++--- tests/python/relax/test_frontend_onnx.py | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 74532aea3b46..4af7115e5ca6 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -4431,7 +4431,18 @@ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule: self._check_for_unsupported_ops(graph) self._construct_nodes(graph) - outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] + # now return the outputs + output_names = [self._parse_value_proto(output) for output in graph.output] + outputs = [] + for output_name in output_names: + output_value = self._nodes[output_name] + if _is_empty_optional(output_value): + raise ValueError( + "ONNX graph output " + f"{output_name} is an empty optional. Empty optional graph outputs " + "are not supported by the Relax ONNX frontend." + ) + outputs.append(output_value) outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs) if has_if: @@ -4631,8 +4642,6 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): if op_name in return_tuple_ops: outputs_num = 1 - elif _is_empty_optional(op): - outputs_num = 1 elif not isinstance(op, relax.Tuple): if isinstance(op.struct_info, relax.TupleStructInfo): # This is a var bound to a tuple. We need to unpack it and create diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 463125258533..c848ef91d6a1 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -3579,6 +3579,23 @@ def test_optional_without_input_requires_type_attr(): from_onnx(model, opset=18, keep_params_in_input=True) +def test_empty_optional_graph_output_raises(): + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3]) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + graph = helper.make_graph( + [optional_node], + "test_empty_optional_graph_output_raises", + inputs=[], + outputs=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_empty_optional_graph_output_raises") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="Empty optional graph outputs are not supported"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_optional_has_element_requires_one_input(): has_element_node = helper.make_node("OptionalHasElement", [], ["output"]) graph = helper.make_graph(