diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index a1173171252b..c858b5a87102 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -317,6 +317,37 @@ def _impl_v13(cls, bb, inputs, attr, params):
         return relax.op.matmul(inputs[0], inputs[1])


+class MatMulInteger16(OnnxOpConverter):
+    """Converts an ONNX MatMulInteger16 node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        if len(inputs) != 2:
+            raise ValueError(f"MatMulInteger16 expects two inputs, but got {len(inputs)}")
+        a, b = inputs
+        valid_types = ["int16", "uint16"]
+        if a.struct_info.dtype not in valid_types:
+            raise ValueError(
+                "MatMulInteger16 expects input A to have int16 or uint16 dtype, "
+                f"but got {a.struct_info.dtype}"
+            )
+        if b.struct_info.dtype not in valid_types:
+            raise ValueError(
+                "MatMulInteger16 expects input B to have int16 or uint16 dtype, "
+                f"but got {b.struct_info.dtype}"
+            )
+
+        out_dtype = (
+            "uint32"
+            if a.struct_info.dtype == "uint16" and b.struct_info.dtype == "uint16"
+            else "int32"
+        )
+        return relax.op.matmul(
+            relax.op.astype(a, out_dtype),
+            relax.op.astype(b, out_dtype),
+        )
+
+
 def _to_numpy(x):
     if isinstance(x, relax.PrimValue):
         x = x.value
@@ -327,6 +358,19 @@ def _to_numpy(x):
     return x.data.numpy()


+class _EmptyOptional:
+    """Sentinel object that preserves an empty ONNX Optional during import."""
+
+    def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto):
+        self.type_proto = type_proto
+
+
+def _is_empty_optional(value: Any) -> bool:
+    """Returns whether the given value represents an empty ONNX Optional."""
+
+    return isinstance(value, _EmptyOptional)
+
+
 class BinaryBase(OnnxOpConverter):
     """Converts an onnx BinaryBase node into an equivalent Relax expression."""

@@ -690,38 +734,55 @@ def _impl_v1(cls, bb, inputs, attr, params):
     def _impl_v13(cls, bb, inputs, attr, params):
         data = inputs[0]
         axes = get_constant(inputs[1], params)
+        data_ndim = _get_known_tensor_rank(data)

-        # Handle ONNX shape inference
         if isinstance(data, relax.PrimValue) and isinstance(axes, relax.Constant):
-            axes = axes.data.numpy().tolist()
-            if axes == [0]:
+            constant_axes = _normalize_constant_axes(
+                list(map(int, axes.data.numpy().tolist())), 1, "Unsqueeze"
+            )
+            if constant_axes == [0]:
                 return relax.ShapeExpr([data.value])
-            else:
-                raise NotImplementedError(
-                    "Unsqueeze with symbolic axes and non-zero axes is not supported."
-                )
-        # If input is a constant, compute directly
+            raise NotImplementedError("Unsqueeze with symbolic scalar inputs only supports axis 0.")
+
         if isinstance(data, relax.Constant) and isinstance(axes, relax.Constant):
-            axes = axes.data.numpy().tolist()
+            constant_axes = _normalize_constant_axes(
+                list(map(int, axes.data.numpy().tolist())),
+                data.data.numpy().ndim + axes.data.numpy().size,
+                "Unsqueeze",
+            )
+            constant_axes = sorted(constant_axes)
             expanded = data.data.numpy()
             if len(expanded.shape) == 0:
-                # Special case implying input is a scalar, wrap it as a list.
- if 0 in axes: - axes.remove(0) expanded = [expanded] - for axis in axes: + constant_axes = [axis - 1 for axis in constant_axes if axis != 0] + for axis in constant_axes: expanded = _np.expand_dims(expanded, axis=axis) return relax.const(expanded, data.struct_info.dtype) if isinstance(axes, relax.Constant): - constant_axes = list(axes.data.numpy()) - constant_axes = list(map(int, constant_axes)) + if data_ndim is None: + raise ValueError("Unsqueeze requires a statically known input rank.") + constant_axes = _normalize_constant_axes( + list(map(int, axes.data.numpy().tolist())), + data_ndim + axes.data.numpy().size, + "Unsqueeze", + ) constant_axes = sorted(constant_axes) for axis in constant_axes: data = relax.op.expand_dims(data, axis=axis) return data - raise NotImplementedError("Unsqueeze with dynamic axes is not supported.") + if data_ndim is None: + raise ValueError("Unsqueeze with dynamic axes requires a statically known input rank.") + axes_len = _get_known_tensor_length(axes) + if axes_len is None: + raise ValueError("Unsqueeze requires a statically known axes length.") + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + output_shape_tensor = _build_unsqueezed_shape_tensor(bb, data_shape_tensor, axes, data_ndim) + output_shape = _tensor_to_shape_expr( + bb, output_shape_tensor, data_ndim + axes_len, "unsqueeze_dim" + ) + return relax.op.reshape(data, output_shape) class Concat(OnnxOpConverter): @@ -1440,14 +1501,37 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.const(out_data, data.struct_info.dtype) if isinstance(data, relax.ShapeExpr): - if axis == (0,): + shape_tensor_ndim = 1 + if axis is None: + if len(data) == 1: + return relax.PrimValue(data[0]) + return data + normalized_axes = _normalize_constant_axes(list(axis), shape_tensor_ndim, "Squeeze") + if normalized_axes == [0] and len(data) == 1: return relax.PrimValue(data[0]) - else: - raise NotImplementedError( - "Squeeze with symbolic axes and non-zero axes is not supported." - ) + raise NotImplementedError( + "Squeeze on symbolic shape tensors only supports removing the sole axis." 
+ ) + + if axis is None: + return relax.op.squeeze(data) - return relax.op.squeeze(data, axis) + if isinstance(axis, tuple): + return relax.op.squeeze(data, list(axis)) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError("Squeeze with dynamic axes requires a statically known input rank.") + axes_len = _get_known_tensor_length(axis) + if axes_len is None: + raise ValueError("Squeeze requires a statically known axes length.") + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + output_shape_tensor = _build_squeezed_shape_tensor(bb, data_shape_tensor, axis, data_ndim) + output_shape = _tensor_to_shape_expr( + bb, output_shape_tensor, data_ndim - axes_len, "squeeze_dim" + ) + return relax.op.reshape(data, output_shape) class Constant(OnnxOpConverter): @@ -1844,68 +1928,308 @@ def get_prim_value_list(values): return new_values +def _get_known_tensor_rank(expr: relax.Expr) -> int | None: + """Return the statically known rank of an expression when available.""" + + if isinstance(expr, relax.Constant): + return len(expr.data.numpy().shape) + if isinstance(expr, relax.ShapeExpr): + return 1 + if isinstance(expr, relax.PrimValue): + return 0 + struct_info = expr.struct_info + if isinstance(struct_info, relax.TensorStructInfo): + return None if struct_info.ndim == -1 else struct_info.ndim + return None + + +def _get_known_tensor_length(expr: relax.Expr | None) -> int | None: + """Return the statically known length of a 1-D tensor-like expression.""" + + if expr is None: + return None + if isinstance(expr, relax.Constant): + np_value = expr.data.numpy() + if np_value.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={np_value.ndim}.") + return int(np_value.shape[0]) + if isinstance(expr, relax.ShapeExpr): + return len(expr.values) + if isinstance(expr, relax.PrimValue): + return 1 + struct_info = expr.struct_info + if not isinstance(struct_info, relax.TensorStructInfo): + return None + if struct_info.ndim != -1 and struct_info.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={struct_info.ndim}.") + if struct_info.ndim != 1: + return None + if isinstance(struct_info.shape, relax.ShapeExpr): + dim = struct_info.shape.values[0] + if isinstance(dim, tirx.IntImm): + return int(dim.value) + if isinstance(dim, int): + return dim + return None + + +def _normalize_constant_axes(axes: list[int], rank: int, op_name: str) -> list[int]: + """Normalize a list of constant axes and validate their uniqueness.""" + + normalized_axes = [] + for axis in axes: + if axis < 0: + axis += rank + if axis < 0 or axis >= rank: + raise ValueError(f"{op_name} axis {axis} is out of range for rank {rank}.") + normalized_axes.append(axis) + if len(normalized_axes) != len(set(normalized_axes)): + raise ValueError(f"{op_name} axes must be unique.") + return normalized_axes + + +def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr: + """Convert a tensor-like expression to an int64 tensor expression.""" + + if isinstance(expr, relax.ShapeExpr): + return bb.normalize(relax.op.shape_to_tensor(expr)) + if isinstance(expr, relax.PrimValue): + return bb.normalize(relax.op.full((1,), expr, dtype="int64")) + if isinstance(expr, relax.Constant): + if expr.struct_info.dtype == "int64": + return expr + return bb.normalize(relax.op.astype(expr, "int64")) + if isinstance(expr.struct_info, relax.TensorStructInfo) and expr.struct_info.dtype != "int64": + return 
bb.normalize(relax.op.astype(expr, "int64")) + return expr + + +def _tensor_to_shape_expr( + bb: relax.BlockBuilder, shape_tensor: relax.Expr, shape_ndim: int, prefix: str +) -> relax.ShapeExpr: + """Convert a statically sized int64 tensor into a ShapeExpr.""" + + shape_tensor = bb.match_cast(shape_tensor, relax.TensorStructInfo([shape_ndim], "int64")) + shape_dataflow_var = bb.emit(relax.op.tensor_to_shape(shape_tensor)) + shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in range(shape_ndim)] + bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) + return relax.ShapeExpr(shape_vars) + + +def _build_unsqueezed_shape_tensor( + bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int +) -> relax.Expr: + """Build the output shape tensor for Unsqueeze with runtime axes.""" + + axes = _as_int64_tensor(bb, axes) + axes_len = _get_known_tensor_length(axes) + if axes_len is None: + raise ValueError("Unsqueeze requires a statically known axes length.") + + output_ndim = data_ndim + axes_len + axes = bb.normalize( + relax.op.where( + relax.op.less(axes, relax.const(0, "int64")), + relax.op.add(axes, relax.const(output_ndim, "int64")), + axes, + ) + ) + positions = relax.op.arange(output_ndim, dtype="int64") + positions = bb.normalize(relax.op.expand_dims(positions, axis=1)) + axes = bb.normalize(relax.op.expand_dims(axes, axis=0)) + insert_mask = bb.normalize( + relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1) + ) + keep_mask = bb.normalize(relax.op.subtract(relax.const(1, "int64"), insert_mask)) + input_indices = bb.normalize( + relax.op.subtract(relax.op.cumsum(keep_mask, axis=0), relax.const(1, "int64")) + ) + safe_indices = bb.normalize( + relax.op.where( + relax.op.less(input_indices, relax.const(0, "int64")), + relax.const(0, "int64"), + input_indices, + ) + ) + kept_dims = bb.normalize(relax.op.take(data_shape_tensor, safe_indices, axis=0)) + return bb.normalize( + relax.op.where( + relax.op.greater(insert_mask, relax.const(0, "int64")), + relax.const(1, "int64"), + kept_dims, + ) + ) + + +def _build_squeezed_shape_tensor( + bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int +) -> relax.Expr: + """Build the output shape tensor for Squeeze with runtime axes.""" + + axes = _as_int64_tensor(bb, axes) + axes = bb.normalize( + relax.op.where( + relax.op.less(axes, relax.const(0, "int64")), + relax.op.add(axes, relax.const(data_ndim, "int64")), + axes, + ) + ) + positions = relax.op.arange(data_ndim, dtype="int64") + positions = bb.normalize(relax.op.expand_dims(positions, axis=1)) + axes = bb.normalize(relax.op.expand_dims(axes, axis=0)) + remove_mask = bb.normalize( + relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1) + ) + keep_mask = bb.normalize(relax.op.equal(remove_mask, relax.const(0, "int64"))) + keep_indices = bb.normalize(relax.op.nonzero(keep_mask)) + num_keep_dims = tirx.Var("squeeze_num_keep_dims", "int64") + keep_indices = bb.match_cast(keep_indices, relax.TensorStructInfo([1, num_keep_dims], "int64")) + keep_indices = bb.normalize(relax.op.reshape(keep_indices, [-1])) + return bb.normalize(relax.op.take(data_shape_tensor, keep_indices, axis=0)) + + class Slice(OnnxOpConverter): - """Converts an onnx Splice node into an equivalent Relax expression.""" + """Converts an onnx Slice node into an equivalent Relax expression.""" @classmethod def _impl_v13(cls, bb, inputs, attr, params): - # TODO (jwfromm) currently only supports constant 
parameters. data = inputs[0] starts = get_constant(inputs[1], params) ends = get_constant(inputs[2], params) axes = get_constant(inputs[3], params) steps = get_constant(inputs[4], params) - if not all( - [ - ( - isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue) - or param is None + all_constant_params = all( + isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue) or param is None + for param in [starts, ends, axes, steps] + ) + if all_constant_params: + starts = get_prim_expr_list(starts) + ends = get_prim_expr_list(ends) + if len(starts) != len(ends): + raise ValueError( + f"Slice expects starts and ends to have the same length, but got " + f"{len(starts)} and {len(ends)}." ) - for param in [starts, ends, axes, steps] - ] - ): - raise ValueError("Only constant Slice parameters are currently supported.") - # Convert parameters to constant lists. - starts = get_prim_expr_list(starts) - ends = get_prim_expr_list(ends) - if axes is not None: - axes = get_prim_expr_list(axes) - else: - axes = list(range(len(starts))) - # Convert negative axis to positive if needed. - for i, axis in enumerate(axes): - if axis < 0: - axes[i] = axis + len(data.struct_info.shape) - if steps is not None: - steps = get_prim_expr_list(steps) - else: - steps = [1] * len(axes) - # If input is a shape tensor, we can directly extract it. - if isinstance(data, relax.ShapeExpr): - shape_data = list(data) - # Starts, ends, and steps must be 1-d for shape operation. - assert all(len(i) == 1 for i in [starts, ends, steps]) - sliced_values = shape_data[starts[0] : ends[0] : steps[0]] - - if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]): - return relax.const([x.value for x in sliced_values], "int64") + if axes is not None: + axes = get_prim_expr_list(axes) + if len(axes) != len(starts): + raise ValueError( + f"Slice expects axes and starts to have the same length, but got " + f"{len(axes)} and {len(starts)}." + ) + else: + axes = list(range(len(starts))) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError("Slice requires a statically known input rank.") + axes = _normalize_constant_axes(list(axes), data_ndim, "Slice") + if steps is not None: + steps = get_prim_expr_list(steps) + if len(steps) != len(starts): + raise ValueError( + f"Slice expects steps and starts to have the same length, but got " + f"{len(steps)} and {len(starts)}." + ) else: + steps = [1] * len(axes) + if any( + (isinstance(step, int) and step == 0) + or (isinstance(step, tirx.IntImm) and int(step) == 0) + for step in steps + ): + raise ValueError("Slice step values must be non-zero.") + if isinstance(data, relax.ShapeExpr): + shape_data = list(data) + assert all(len(i) == 1 for i in [starts, ends, steps]) + sliced_values = shape_data[starts[0] : ends[0] : steps[0]] + + if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]): + return relax.const([x.value for x in sliced_values], "int64") return relax.ShapeExpr(sliced_values) - # If all `starts`, `ends`, and `steps` are constant, use strict mode - # Otherwise, we assume the slice is inbound. 
- assume_inbound = not all( - [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] - ) + assume_inbound = not all( + [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] + ) + starts = get_prim_value_list(starts) + ends = get_prim_value_list(ends) + steps = get_prim_value_list(steps) - # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr - starts = get_prim_value_list(starts) - ends = get_prim_value_list(ends) - steps = get_prim_value_list(steps) + return relax.op.strided_slice( + data, axes, starts, ends, steps, assume_inbound=assume_inbound + ) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError( + "Slice with dynamic parameters requires a statically known input rank." + ) + + data_expr = data + if isinstance(data, relax.ShapeExpr): + data_expr = bb.normalize(relax.op.shape_to_tensor(data)) + + starts_tensor = _as_int64_tensor(bb, starts) + ends_tensor = _as_int64_tensor(bb, ends) + axes_len = _get_known_tensor_length(starts_tensor) + if axes_len is None: + raise ValueError("Slice requires a statically known starts length.") + ends_len = _get_known_tensor_length(ends_tensor) + if ends_len is None: + raise ValueError("Slice requires a statically known ends length.") + if ends_len != axes_len: + raise ValueError( + f"Slice expects starts and ends to have the same length, but got " + f"{axes_len} and {ends_len}." + ) + + if axes is None: + axes_tensor = relax.op.arange(axes_len, dtype="int64") + else: + axes_tensor = _as_int64_tensor(bb, axes) + axes_tensor_len = _get_known_tensor_length(axes_tensor) + if axes_tensor_len is None: + raise ValueError("Slice requires a statically known axes length.") + if axes_tensor_len != axes_len: + raise ValueError( + f"Slice expects axes and starts to have the same length, but got " + f"{axes_tensor_len} and {axes_len}." + ) + if steps is None: + steps_tensor = relax.const(_np.ones((axes_len,), dtype="int64"), "int64") + else: + steps_tensor = _as_int64_tensor(bb, steps) + steps_len = _get_known_tensor_length(steps_tensor) + if steps_len is None: + raise ValueError("Slice requires a statically known steps length.") + if steps_len != axes_len: + raise ValueError( + f"Slice expects steps and starts to have the same length, but got " + f"{steps_len} and {axes_len}." 
+ ) + if isinstance(steps_tensor, relax.Constant) and _np.any(steps_tensor.data.numpy() == 0): + raise ValueError("Slice step values must be non-zero.") + + axes_tensor = bb.normalize( + relax.op.where( + relax.op.less(axes_tensor, relax.const(0, "int64")), + relax.op.add(axes_tensor, relax.const(data_ndim, "int64")), + axes_tensor, + ) + ) - return relax.op.strided_slice( - data, axes, starts, ends, steps, assume_inbound=assume_inbound + data_shape = bb.normalize(relax.op.shape_of(data_expr)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + full_starts = relax.const(_np.zeros((data_ndim,), dtype="int64"), "int64") + full_steps = relax.const(_np.ones((data_ndim,), dtype="int64"), "int64") + full_starts = bb.normalize( + relax.op.scatter_elements(full_starts, axes_tensor, starts_tensor) + ) + full_ends = bb.normalize( + relax.op.scatter_elements(data_shape_tensor, axes_tensor, ends_tensor) ) + full_steps = bb.normalize(relax.op.scatter_elements(full_steps, axes_tensor, steps_tensor)) + return relax.op.dynamic_strided_slice(data_expr, full_starts, full_ends, full_steps) class Pad(OnnxOpConverter): @@ -2421,9 +2745,7 @@ def _impl_v20(cls, bb, inputs, attr, params): align_corners = attr.get("align_corners", 0) if align_corners != 1: - raise NotImplementedError( - "AffineGrid with align_corners=0 is not yet supported in TVM" - ) + raise NotImplementedError("AffineGrid with align_corners=0 is not yet supported in TVM") # Extract size values if isinstance(size, relax.Constant): @@ -3663,6 +3985,50 @@ def _impl_v1(cls, bb, inputs, attr, params): ) +class Optional_(OnnxOpConverter): + """Converts an ONNX Optional node into an erased or empty Optional representation.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) > 1: + raise ValueError(f"Optional accepts at most one input, but got {len(inputs)}") + if len(inputs) == 0 or inputs[0] is None: + if "type" not in attr: + raise ValueError("Optional without an input must specify the type attribute.") + return _EmptyOptional(attr["type"]) + return inputs[0] + + _impl_v18 = _impl_v15 + + +class OptionalHasElement(OnnxOpConverter): + """Converts an ONNX OptionalHasElement node into a boolean constant.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) != 1: + raise ValueError(f"OptionalHasElement expects one input, but got {len(inputs)}") + if inputs[0] is None or _is_empty_optional(inputs[0]): + return relax.const(False, dtype="bool") + return relax.const(True, dtype="bool") + + _impl_v18 = _impl_v15 + + +class OptionalGetElement(OnnxOpConverter): + """Converts an ONNX OptionalGetElement node by unwrapping a non-empty Optional.""" + + @classmethod + def _impl_v15(cls, bb, inputs, attr, params): + if len(inputs) != 1: + raise ValueError(f"OptionalGetElement expects one input, but got {len(inputs)}") + if inputs[0] is None or _is_empty_optional(inputs[0]): + raise ValueError("OptionalGetElement cannot access an empty optional.") + return inputs[0] + + _impl_v18 = _impl_v15 + + class SequenceConstruct(OnnxOpConverter): """Operator converter for sequence construction op.""" @@ -4034,9 +4400,9 @@ def _impl_v16(cls, bb, inputs, attr, params): def _get_convert_map(): return { # defs/experimental - # "Optional": Optional_, - # "OptionalHasElement": OptionalHasElement, - # "OptionalGetElement": OptionalGetElement, + "Optional": Optional_, + "OptionalHasElement": OptionalHasElement, + "OptionalGetElement": OptionalGetElement, # Binary operators "Add": Add, "Sub": Sub, @@ 
-4107,7 +4473,7 @@ def _get_convert_map(): "Gemm": Gemm, "MatMul": MatMul, # "MatMulInteger": MatMulInteger, - # "MatMulInteger16": MatMulInteger16, + "MatMulInteger16": MatMulInteger16, "Reshape": Reshape, "Sigmoid": Sigmoid, "Softmax": Softmax, @@ -4418,6 +4784,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): "Squeeze", ] return_tuple_ops = [ + "Optional", + "OptionalGetElement", "SequenceConstruct", "SequenceEmpty", "SequenceErase", @@ -4436,13 +4804,16 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): try: op = self._convert_operator(op_name, inputs, attr, self.opset) # Create struct information for the new operator. - op = self.bb.normalize(op) + if isinstance(op, relax.Expr): + op = self.bb.normalize(op) except TVMError as err: print(f"Error converting operator {op_name}, with inputs: {inputs}") raise err if op_name in return_tuple_ops: outputs_num = 1 + elif _is_empty_optional(op): + outputs_num = 1 elif not isinstance(op, relax.Tuple): if isinstance(op.struct_info, relax.TupleStructInfo): # This is a var bound to a tuple. We need to unpack it and create @@ -4488,11 +4859,11 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> dict[str, if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) - for f in ["t"]: - if a.HasField(f): + for f in ["t", "tp"]: + if hasattr(a, f) and a.HasField(f): attrs[a.name] = getattr(a, f) - for f in ["tensors"]: - if list(getattr(a, f)): + for f in ["tensors", "type_protos"]: + if hasattr(a, f) and list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) for f in ["graphs"]: diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 887533f26139..ae6e2b9e0cf5 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -197,6 +197,65 @@ def _check_output(tvm_out, ort_out): _check_output(tvm_out, ort_out) +def run_in_tvm( + model: ModelProto, + inputs: dict[str, np.ndarray] | None = None, + ir_version: int = 8, + opset: int = 14, +): + if ir_version is not None: + model.ir_version = ir_version + if opset is not None: + for opset_import in model.opset_import: + if opset_import.domain in ["", "ai.onnx"]: + opset_import.version = opset + break + + inputs = generate_random_inputs(model, inputs) + tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(tvm_model, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + input_list = [ + inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs + ] + if params: + input_list += params["main"] + + vm.set_input("main", *input_list) + vm.invoke_stateful("main") + return vm.get_outputs("main") + + +def collect_relax_call_ops(func: relax.Function) -> list[str]: + op_names = [] + + def fvisit(expr): + if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op): + op_names.append(expr.op.name) + + relax.analysis.post_order_visit(func.body, fvisit) + return op_names + + +def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]: + values = [] + + def fvisit(expr): + if isinstance(expr, relax.Constant): + value = 
expr.data.numpy() + if value.shape == (): + values.append(value.item()) + + relax.analysis.post_order_visit(func.body, fvisit) + return values + + @pytest.mark.parametrize( "input_names, expected_names", [ @@ -374,6 +433,101 @@ def test_matmul(dynamic): check_correctness(model, inputs) +@pytest.mark.parametrize( + ("a_dtype", "b_dtype", "a_shape", "b_shape"), + [ + (np.int16, np.int16, [2, 3], [3, 4]), + (np.uint16, np.uint16, [2, 3], [3, 4]), + (np.int16, np.uint16, [2, 1, 3, 5], [1, 2, 5, 4]), + ], +) +def test_matmulinteger16(a_dtype, b_dtype, a_shape, b_shape): + a = np.arange(np.prod(a_shape), dtype=np.int64).reshape(a_shape) + b = np.arange(np.prod(b_shape), dtype=np.int64).reshape(b_shape) + if np.issubdtype(a_dtype, np.signedinteger): + a -= a.size // 2 + if np.issubdtype(b_dtype, np.signedinteger): + b -= b.size // 2 + a = a.astype(a_dtype) + b = b.astype(b_dtype) + + out_dtype = np.uint32 if a_dtype == np.uint16 and b_dtype == np.uint16 else np.int32 + expected = np.matmul(a.astype(out_dtype), b.astype(out_dtype)) + + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_test", + inputs=[ + helper.make_tensor_value_info("a", helper.np_dtype_to_tensor_dtype(a.dtype), a_shape), + helper.make_tensor_value_info("b", helper.np_dtype_to_tensor_dtype(b.dtype), b_shape), + ], + outputs=[ + helper.make_tensor_value_info( + "y", helper.np_dtype_to_tensor_dtype(np.dtype(out_dtype)), expected.shape + ) + ], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + tvm_output = run_in_tvm(model, inputs={"a": a, "b": b}, ir_version=11, opset=18) + assert isinstance(tvm_output, tvm.runtime.Tensor) + assert tvm_output.numpy().dtype == out_dtype + tvm.testing.assert_allclose(tvm_output.numpy(), expected) + + +def test_matmulinteger16_ir(): + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_ir_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.UINT16, [2, 3]), + helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.UINT32, [2, 4])], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_ir_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + assert call_ops.count("relax.astype") == 2 + assert "relax.matmul" in call_ops + assert tvm_model["main"].ret_struct_info.dtype == "uint32" + + +def test_matmulinteger16_invalid_dtype_raises(): + node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "matmulinteger16_invalid_dtype_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.INT8, [2, 3]), + helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, [2, 4])], + ) + model = helper.make_model( + graph, + producer_name="matmulinteger16_invalid_dtype_test", + opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 11 + + with pytest.raises(ValueError, match="input A"): + 
from_onnx(model, opset=18, keep_params_in_input=True) + + def test_concat(): verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) @@ -904,6 +1058,82 @@ def test_unsqueeze(): check_correctness(model) +def test_unsqueeze_dynamic_axes(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_test") + inputs = { + "a": rg.standard_normal(size=[32, 32]).astype("float32"), + "axes": np.array([-1, 0], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_unsqueeze_dynamic_axes_ir(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes_ir", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.tensor_to_shape" in call_ops + assert "relax.reshape" in call_ops + + +def test_unsqueeze_dynamic_axes_rank_validation(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes_rank_validation", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_rank_validation_test") + with pytest.raises(ValueError, match="Expected a 1-D tensor"): + from_onnx(model, opset=13, keep_params_in_input=True) + + +def test_unsqueeze_duplicate_axes_validation(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_duplicate_axes_validation", + inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])], + initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 0])], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1, 32, 32])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_duplicate_axes_validation_test") + with pytest.raises(ValueError, match="axes must be unique"): + from_onnx(model, opset=13) + + def test_unsqueeze_v1(): # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1 unsqueeze_node = helper.make_node("Unsqueeze", ["a"], ["b"], axes=[0, 2, 3]) @@ -1384,6 +1614,70 @@ def test_dynamic_squeeze(axis, A, B): check_correctness(model, inputs, opset=13) +def test_squeeze_dynamic_axes(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, 
[2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_test") + inputs = { + "x": rg.standard_normal(size=shape).astype("float32"), + "axes": np.array([-4, 2], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_squeeze_dynamic_axes_ir(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_ir", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.tensor_to_shape" in call_ops + assert "relax.reshape" in call_ops + assert "relax.squeeze" not in call_ops + + +def test_squeeze_dynamic_axes_rank_validation(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_rank_validation", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_rank_validation_test") + with pytest.raises(ValueError, match="Expected a 1-D tensor"): + from_onnx(model, opset=13, keep_params_in_input=True) + + @pytest.mark.parametrize("axis", [[0]]) @pytest.mark.parametrize("A", [8, 16, 32]) def test_dynamic_shape_squeeze(axis, A): @@ -2287,6 +2581,99 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None): # ) +def test_slice_dynamic_inputs(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [2]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_test") + inputs = { + "x": rg.standard_normal(size=[20, 10, 5]).astype("float32"), + "starts": np.array([0, 0], dtype="int64"), + "ends": np.array([3, 10], dtype="int64"), + "axes": np.array([0, 1], dtype="int64"), + "steps": np.array([1, 1], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_slice_dynamic_inputs_ir(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_ir", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [2]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + 
helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.dynamic_strided_slice" in call_ops + assert "relax.strided_slice" not in call_ops + + +def test_slice_dynamic_inputs_length_validation(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_length_validation", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [1]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_length_validation_test") + with pytest.raises(ValueError, match="starts and ends to have the same length"): + from_onnx(model, opset=13, keep_params_in_input=True) + + +def test_slice_zero_step_validation(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_zero_step_validation", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5])], + initializer=[ + helper.make_tensor("starts", TensorProto.INT64, [2], vals=[0, 0]), + helper.make_tensor("ends", TensorProto.INT64, [2], vals=[3, 10]), + helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1]), + helper.make_tensor("steps", TensorProto.INT64, [2], vals=[1, 0]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_zero_step_validation_test") + with pytest.raises(ValueError, match="step values must be non-zero"): + from_onnx(model, opset=13) + + def test_slice_dynamic_shape(): def verify_slice( data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None @@ -3169,6 +3556,21 @@ def make_constant_node(name: str, data_type: int, dims: list[int], vals: list[in ) +def make_optional_tensor_value_info(name: str, elem_type: int, shape: list[int]): + return helper.make_value_info( + name, helper.make_optional_type_proto(helper.make_tensor_type_proto(elem_type, shape)) + ) + + +def make_optional_sequence_value_info(name: str, elem_type: int, shape: list[int]): + return helper.make_value_info( + name, + helper.make_optional_type_proto( + helper.make_sequence_type_proto(helper.make_tensor_type_proto(elem_type, shape)) + ), + ) + + def test_sequence_construct(): node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=2) graph = helper.make_graph( @@ -3280,6 +3682,163 @@ def test_sequence_at(): check_correctness(model) +def test_optional_get_element_tensor(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_tensor", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + 
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_tensor") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_tensor(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_tensor", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_tensor") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_empty(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_empty", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_empty") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_has_element_empty_ir(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + has_element_node = helper.make_node("OptionalHasElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, has_element_node], + "test_optional_has_element_empty_ir", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_empty_ir") + model.ir_version = 11 + model.opset_import[0].version = 18 + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + + assert collect_relax_call_ops(tvm_model["main"]) == [] + assert False in collect_scalar_constants(tvm_model["main"]) + + +def test_optional_get_element_tensor_ir(): + x_shape = [2, 3] + optional_node = helper.make_node("Optional", ["x"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_tensor_ir", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, x_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[make_optional_tensor_value_info("optional", TensorProto.FLOAT, x_shape)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_tensor_ir") + model.ir_version = 11 + model.opset_import[0].version = 18 + tvm_model = from_onnx(model, opset=18, keep_params_in_input=True) + + assert 
collect_relax_call_ops(tvm_model["main"]) == [] + assert tvm_model["main"].ret_struct_info.dtype == "float32" + + +def test_optional_get_element_sequence(): + seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], num_tensors=4) + index = make_constant_node("index", TensorProto.INT64, (), [1]) + optional_node = helper.make_node("Optional", ["sequence"], ["optional"]) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["unwrapped"]) + sequence_at_node = helper.make_node("SequenceAt", ["unwrapped", "index"], ["output"]) + graph = helper.make_graph( + [index, seq_node, optional_node, get_element_node, sequence_at_node], + "test_optional_get_element_sequence", + inputs=graph_inputs, + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 32])], + value_info=[make_optional_sequence_value_info("optional", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_sequence") + check_correctness(model, opset=18, ir_version=11) + + +def test_optional_without_input_requires_type_attr(): + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3]) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"]) + graph = helper.make_graph( + [optional_node], + "test_optional_without_input_requires_type_attr", + inputs=[], + outputs=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_without_input_requires_type_attr") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="type attribute"): + from_onnx(model, opset=18, keep_params_in_input=True) + + +def test_optional_has_element_requires_one_input(): + has_element_node = helper.make_node("OptionalHasElement", [], ["output"]) + graph = helper.make_graph( + [has_element_node], + "test_optional_has_element_requires_one_input", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, [])], + ) + model = helper.make_model(graph, producer_name="test_optional_has_element_requires_one_input") + model.opset_import[0].version = 18 + + with pytest.raises(ValueError, match="expects one input"): + from_onnx(model, opset=18, keep_params_in_input=True) + + +def test_optional_get_element_empty_raises(): + x_shape = [2, 3] + tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape) + optional_type = helper.make_optional_type_proto(tensor_type) + optional_node = helper.make_node("Optional", [], ["optional"], type=tensor_type) + get_element_node = helper.make_node("OptionalGetElement", ["optional"], ["output"]) + graph = helper.make_graph( + [optional_node, get_element_node], + "test_optional_get_element_empty_raises", + inputs=[], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, x_shape)], + value_info=[helper.make_value_info("optional", optional_type)], + ) + model = helper.make_model(graph, producer_name="test_optional_get_element_empty_raises") + model.opset_import[0].version = 18 + with pytest.raises(ValueError, match="empty optional"): + from_onnx(model, opset=18, keep_params_in_input=True) + + def test_symbolic_shape_deduction(): index_node = helper.make_node( "Constant",