diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4cc4e99b7b02..4af7115e5ca6 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -318,6 +318,37 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.matmul(inputs[0], inputs[1]) +class MatMulInteger16(OnnxOpConverter): + """Converts an ONNX MatMulInteger16 node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if len(inputs) != 2: + raise ValueError(f"MatMulInteger16 expects two inputs, but got {len(inputs)}") + a, b = inputs + valid_types = ["int16", "uint16"] + if a.struct_info.dtype not in valid_types: + raise ValueError( + "MatMulInteger16 expects input A to have int16 or uint16 dtype, " + f"but got {a.struct_info.dtype}" + ) + if b.struct_info.dtype not in valid_types: + raise ValueError( + "MatMulInteger16 expects input B to have int16 or uint16 dtype, " + f"but got {b.struct_info.dtype}" + ) + + out_dtype = ( + "uint32" + if a.struct_info.dtype == "uint16" and b.struct_info.dtype == "uint16" + else "int32" + ) + return relax.op.matmul( + relax.op.astype(a, out_dtype), + relax.op.astype(b, out_dtype), + ) + + def _to_numpy(x): if isinstance(x, relax.PrimValue): x = x.value @@ -328,6 +359,19 @@ def _to_numpy(x): return x.data.numpy() +class _EmptyOptional: + """Sentinel object that preserves an empty ONNX Optional during import.""" + + def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto): + self.type_proto = type_proto + + +def _is_empty_optional(value: Any) -> bool: + """Returns whether the given value represents an empty ONNX Optional.""" + + return isinstance(value, _EmptyOptional) + + class BinaryBase(OnnxOpConverter): """Converts an onnx BinaryBase node into an equivalent Relax expression.""" @@ -3686,6 +3730,50 @@ def _impl_v1(cls, bb, inputs, attr, params): ) +class Optional_(OnnxOpConverter): + """Converts an ONNX Optional node into an erased or empty Optional representation.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) > 1: + raise ValueError(f"Optional accepts at most one input, but got {len(inputs)}") + if len(inputs) == 0 or inputs[0] is None: + if "type" not in attr: + raise ValueError("Optional without an input must specify the type attribute.") + return _EmptyOptional(attr["type"]) + return inputs[0] + + _impl_v18 = _impl_v15 + + +class OptionalHasElement(OnnxOpConverter): + """Converts an ONNX OptionalHasElement node into a boolean constant.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) != 1: + raise ValueError(f"OptionalHasElement expects one input, but got {len(inputs)}") + if inputs[0] is None or _is_empty_optional(inputs[0]): + return relax.const(False, dtype="bool") + return relax.const(True, dtype="bool") + + _impl_v18 = _impl_v15 + + +class OptionalGetElement(OnnxOpConverter): + """Converts an ONNX OptionalGetElement node by unwrapping a non-empty Optional.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) != 1: + raise ValueError(f"OptionalGetElement expects one input, but got {len(inputs)}") + if inputs[0] is None or _is_empty_optional(inputs[0]): + raise ValueError("OptionalGetElement cannot access an empty optional.") + return inputs[0] + + _impl_v18 = _impl_v15 + + class SequenceConstruct(OnnxOpConverter): """Operator converter for sequence construction op.""" @@ -4111,9 +4199,9 @@ def _impl_v10(cls, bb, inputs, attr, params): def _get_convert_map(): return { # defs/experimental - # "Optional": Optional_, - # "OptionalHasElement": OptionalHasElement, - # "OptionalGetElement": OptionalGetElement, + "Optional": Optional_, + "OptionalHasElement": OptionalHasElement, + "OptionalGetElement": OptionalGetElement, # Binary operators "Add": Add, "Sub": Sub, @@ -4184,7 +4272,7 @@ def _get_convert_map(): "Gemm": Gemm, "MatMul": MatMul, "MatMulInteger": MatMulInteger, - # "MatMulInteger16": MatMulInteger16, + "MatMulInteger16": MatMulInteger16, "Reshape": Reshape, "Sigmoid": Sigmoid, "Softmax": Softmax, @@ -4343,7 +4431,18 @@ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule: self._check_for_unsupported_ops(graph) self._construct_nodes(graph) - outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] + # now return the outputs + output_names = [self._parse_value_proto(output) for output in graph.output] + outputs = [] + for output_name in output_names: + output_value = self._nodes[output_name] + if _is_empty_optional(output_value): + raise ValueError( + "ONNX graph output " + f"{output_name} is an empty optional. Empty optional graph outputs " + "are not supported by the Relax ONNX frontend." + ) + outputs.append(output_value) outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs) if has_if: @@ -4515,6 +4614,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): "Squeeze", ] return_tuple_ops = [ + "Optional", + "OptionalGetElement", "SequenceConstruct", "SequenceEmpty", "SequenceErase", @@ -4533,7 +4634,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): try: op = self._convert_operator(op_name, inputs, attr, self.opset) # Create struct information for the new operator. - op = self.bb.normalize(op) + if isinstance(op, relax.Expr): + op = self.bb.normalize(op) except TVMError as err: print(f"Error converting operator {op_name}, with inputs: {inputs}") raise err @@ -4585,11 +4687,11 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> dict[str, if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) - for f in ["t"]: - if a.HasField(f): + for f in ["t", "tp"]: + if hasattr(a, f) and a.HasField(f): attrs[a.name] = getattr(a, f) - for f in ["tensors"]: - if list(getattr(a, f)): + for f in ["tensors", "type_protos"]: + if hasattr(a, f) and list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) for f in ["graphs"]: diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index d04b0c2f33f6..c848ef91d6a1 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -197,6 +197,65 @@ def _check_output(tvm_out, ort_out): _check_output(tvm_out, ort_out) +def run_in_tvm( + model: ModelProto, + inputs: dict[str, np.ndarray] | None = None, + ir_version: int = 8, + opset: int = 14, +): + if ir_version is not None: + model.ir_version = ir_version + if opset is not None: + for opset_import in model.opset_import: + if opset_import.domain in ["", "ai.onnx"]: + opset_import.version = opset + break + + inputs = generate_random_inputs(model, inputs) + tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + return vm.get_outputs("main") + + +def collect_relax_call_ops(func: relax.Function) -> list[str]: + op_names = [] + + def fvisit(expr): + if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op): + op_names.append(expr.op.name) + + relax.analysis.post_order_visit(func.body, fvisit) + return op_names + + +def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]: + values = [] + + def fvisit(expr): + if isinstance(expr, relax.Constant): + value = expr.data.numpy() + if value.shape == (): + values.append(value.item()) + + relax.analysis.post_order_visit(func.body, fvisit) + return values + + @pytest.mark.parametrize( "input_names, expected_names", [ @@ -374,6 +433,101 @@ def test_matmul(dynamic): check_correctness(model, inputs) +@pytest.mark.parametrize( + ("a_dtype", "b_dtype", "a_shape", "b_shape"), + [ + (np.int16, np.int16, [2, 3], [3, 4]), + (np.uint16, np.uint16, [2, 3], [3, 4]), + (np.int16, np.uint16, [2, 1, 3, 5], [1, 2, 5, 4]), + ], +) +def test_matmulinteger16(a_dtype, b_dtype, a_shape, b_shape): + a = np.arange(np.prod(a_shape), dtype=np.int64).reshape(a_shape) + b = np.arange(np.prod(b_shape), dtype=np.int64).reshape(b_shape) + if np.issubdtype(a_dtype, np.signedinteger): + a -= a.size // 2 + if np.issubdtype(b_dtype, np.signedinteger): + b -= b.size // 2 + a = a.astype(a_dtype) + b = b.astype(b_dtype) + + out_dtype = np.uint32 if a_dtype == np.uint16 and b_dtype == np.uint16 else np.int32 + expected = np.matmul(a.astype(out_dtype), b.astype(out_dtype)) + + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_test", + inputs=[ + helper.make_tensor_value_info("a", helper.np_dtype_to_tensor_dtype(a.dtype), a_shape), + helper.make_tensor_value_info("b", helper.np_dtype_to_tensor_dtype(b.dtype), b_shape), + ], + outputs=[ + helper.make_tensor_value_info( + "y", helper.np_dtype_to_tensor_dtype(np.dtype(out_dtype)), expected.shape + ) + ], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + tvm_output = run_in_tvm(model, inputs={"a": a, "b": b}, ir_version=11, opset=18) + assert isinstance(tvm_output, tvm.runtime.Tensor) + assert tvm_output.numpy().dtype == out_dtype + tvm.testing.assert_allclose(tvm_output.numpy(), expected) + + +def test_matmulinteger16_ir(): + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_ir_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.UINT16, [2, 3]), + helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.UINT32, [2, 4])], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_ir_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + assert call_ops.count("relax.astype") == 2 + assert "relax.matmul" in call_ops + assert tvm_model["main"].ret_struct_info.dtype == "uint32" + + +def test_matmulinteger16_invalid_dtype_raises(): + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_invalid_dtype_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.INT8, [2, 3]), + helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, [2, 4])], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_invalid_dtype_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + with pytest.raises(ValueError, match="input A"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_concat(): verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) @@ -3176,6 +3330,21 @@ def make_constant_node(name: str, data_type: int, dims: list[int], vals: list[in ) +def make_optional_tensor_value_info(name: str, elem_type: int, shape: list[int]): + return helper.make_value_info( + name, helper.make_optional_type_proto(helper.make_tensor_type_proto(elem_type, shape)) + ) + + +def make_optional_sequence_value_info(name: str, elem_type: int, shape: list[int]): + return helper.make_value_info( + name, + helper.make_optional_type_proto( + helper.make_sequence_type_proto(helper.make_tensor_type_proto(elem_type, shape)) + ), + ) + + def test_sequence_construct(): node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=2) graph = helper.make_graph( @@ -3287,6 +3456,180 @@ def test_sequence_at(): check_correctness(model) +def test_optional_get_element_tensor(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_tensor", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_tensor") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_tensor(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_tensor", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_tensor") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_empty(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_empty", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_empty") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_empty_ir(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_empty_ir", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_empty_ir") + model.ir_version = 11 + model.opset_import[0].version = 18 + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + + assert collect_relax_call_ops(tvm_model["main"]) == [] + assert False in collect_scalar_constants(tvm_model["main"]) + + +def test_optional_get_element_tensor_ir(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_tensor_ir", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_tensor_ir") + model.ir_version = 11 + model.opset_import[0].version = 18 + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + + assert collect_relax_call_ops(tvm_model["main"]) == [] + assert tvm_model["main"].ret_struct_info.dtype == "float32" + + +def test_optional_get_element_sequence(): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [1]) + optional_node = helper.make_node("Optional", ["sequence"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["unwrapped"]) + sequence_at_node = helper.make_node("SequenceAt", ["unwrapped", "index"], ["output"]) + graph = helper.make_graph( + [index, seq_node, optional_node, get_element_node, sequence_at_node], + "test_optional_get_element_sequence", + inputs=graph_inputs, + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 32])], + value_info=[make_optional_sequence_value_info("optional", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_sequence") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_without_input_requires_type_attr(): + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3]) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"]) + graph = helper.make_graph( + [optional_node], + "test_optional_without_input_requires_type_attr", + inputs=[], + outputs=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_without_input_requires_type_attr") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="type attribute"): + from_onnx(model, opset=18, keep_params_in_input=True) + + +def test_empty_optional_graph_output_raises(): + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3]) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + graph = helper.make_graph( + [optional_node], + "test_empty_optional_graph_output_raises", + inputs=[], + outputs=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_empty_optional_graph_output_raises") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="Empty optional graph outputs are not supported"): + from_onnx(model, opset=18, keep_params_in_input=True) + + +def test_optional_has_element_requires_one_input(): + has_element_node = helper.make_node("OptionalHasElement", [], ["output"]) + graph = helper.make_graph( + [has_element_node], + "test_optional_has_element_requires_one_input", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_requires_one_input") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="expects one input"): + from_onnx(model, opset=18, keep_params_in_input=True) + + +def test_optional_get_element_empty_raises(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_empty_raises", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_empty_raises") + model.opset_import[0].version = 18 + with pytest.raises(ValueError, match="empty optional"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_symbolic_shape_deduction(): index_node = helper.make_node( "Constant",