diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4af7115e5ca6..a113534286e2 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -735,38 +735,59 @@ def _impl_v1(cls, bb, inputs, attr, params): def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] axes = get_constant(inputs[1], params) + data_ndim = _get_known_tensor_rank(data) - # Handle ONNX shape inference if isinstance(data, relax.PrimValue) and isinstance(axes, relax.Constant): - axes = axes.data.numpy().tolist() - if axes == [0]: + constant_axes = _normalize_constant_axes( + list(map(int, axes.data.numpy().tolist())), 1, "Unsqueeze" + ) + if constant_axes == [0]: return relax.ShapeExpr([data.value]) - else: - raise NotImplementedError( - "Unsqueeze with symbolic axes and non-zero axes is not supported." - ) - # If input is a constant, compute directly + raise NotImplementedError("Unsqueeze with symbolic scalar inputs only supports axis 0.") if isinstance(data, relax.Constant) and isinstance(axes, relax.Constant): - axes = axes.data.numpy().tolist() + constant_axes = _normalize_constant_axes( + list(map(int, axes.data.numpy().tolist())), + data.data.numpy().ndim + axes.data.numpy().size, + "Unsqueeze", + ) + constant_axes = sorted(constant_axes) expanded = data.data.numpy() - if len(expanded.shape) == 0: - # Special case implying input is a scalar, wrap it as a list. - if 0 in axes: - axes.remove(0) - expanded = [expanded] - for axis in axes: - expanded = _np.expand_dims(expanded, axis=axis) + output_rank = expanded.ndim + len(constant_axes) + new_shape = [] + input_dims_iter = iter(expanded.shape) + for i in range(output_rank): + if i in constant_axes: + new_shape.append(1) + else: + new_shape.append(next(input_dims_iter)) + expanded = expanded.reshape(new_shape) return relax.const(expanded, data.struct_info.dtype) if isinstance(axes, relax.Constant): - constant_axes = list(axes.data.numpy()) - constant_axes = list(map(int, constant_axes)) + if data_ndim is None: + raise ValueError("Unsqueeze requires a statically known input rank.") + constant_axes = _normalize_constant_axes( + list(map(int, axes.data.numpy().tolist())), + data_ndim + axes.data.numpy().size, + "Unsqueeze", + ) constant_axes = sorted(constant_axes) for axis in constant_axes: data = relax.op.expand_dims(data, axis=axis) return data - raise NotImplementedError("Unsqueeze with dynamic axes is not supported.") + if data_ndim is None: + raise ValueError("Unsqueeze with dynamic axes requires a statically known input rank.") + axes_len = _get_known_tensor_length(axes) + if axes_len is None: + raise ValueError("Unsqueeze requires a statically known axes length.") + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + output_shape_tensor = _build_unsqueezed_shape_tensor(bb, data_shape_tensor, axes, data_ndim) + output_shape = _tensor_to_shape_expr( + bb, output_shape_tensor, data_ndim + axes_len, "unsqueeze_dim" + ) + return relax.op.reshape(data, output_shape) class Concat(OnnxOpConverter): @@ -1487,14 +1508,37 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.const(out_data, data.struct_info.dtype) if isinstance(data, relax.ShapeExpr): - if axis == (0,): + shape_tensor_ndim = 1 + if axis is None: + if len(data) == 1: + return relax.PrimValue(data[0]) + return data + normalized_axes = _normalize_constant_axes(list(axis), shape_tensor_ndim, "Squeeze") + if normalized_axes == [0] and len(data) == 1: return relax.PrimValue(data[0]) - else: - raise NotImplementedError( - "Squeeze with symbolic axes and non-zero axes is not supported." - ) + raise NotImplementedError( + "Squeeze on symbolic shape tensors only supports removing the sole axis." + ) - return relax.op.squeeze(data, axis) + if axis is None: + return relax.op.squeeze(data) + + if isinstance(axis, tuple): + return relax.op.squeeze(data, list(axis)) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError("Squeeze with dynamic axes requires a statically known input rank.") + axes_len = _get_known_tensor_length(axis) + if axes_len is None: + raise ValueError("Squeeze requires a statically known axes length.") + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + output_shape_tensor = _build_squeezed_shape_tensor(bb, data_shape_tensor, axis, data_ndim) + output_shape = _tensor_to_shape_expr( + bb, output_shape_tensor, data_ndim - axes_len, "squeeze_dim" + ) + return relax.op.reshape(data, output_shape) class Constant(OnnxOpConverter): @@ -1891,68 +1935,309 @@ def get_prim_value_list(values): return new_values +def _get_known_tensor_rank(expr: relax.Expr) -> int | None: + """Return the statically known rank of an expression when available.""" + + if isinstance(expr, relax.Constant): + return len(expr.data.numpy().shape) + if isinstance(expr, relax.ShapeExpr): + return 1 + if isinstance(expr, relax.PrimValue): + return 0 + struct_info = expr.struct_info + if isinstance(struct_info, relax.TensorStructInfo): + return None if struct_info.ndim == -1 else struct_info.ndim + return None + + +def _get_known_tensor_length(expr: relax.Expr | None) -> int | None: + """Return the statically known length of a 1-D tensor-like expression.""" + + if expr is None: + return None + if isinstance(expr, relax.Constant): + np_value = expr.data.numpy() + if np_value.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={np_value.ndim}.") + return int(np_value.shape[0]) + if isinstance(expr, relax.ShapeExpr): + return len(expr.values) + if isinstance(expr, relax.PrimValue): + return 1 + struct_info = expr.struct_info + if not isinstance(struct_info, relax.TensorStructInfo): + return None + if struct_info.ndim == -1: + return None + if struct_info.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={struct_info.ndim}.") + if isinstance(struct_info.shape, relax.ShapeExpr): + dim = struct_info.shape.values[0] + if isinstance(dim, tirx.IntImm): + return int(dim.value) + if isinstance(dim, int): + return dim + return None + + +def _normalize_constant_axes(axes: list[int], rank: int, op_name: str) -> list[int]: + """Normalize a list of constant axes and validate their uniqueness.""" + + normalized_axes = [] + for axis in axes: + original_axis = axis + if axis < 0: + axis += rank + if axis < 0 or axis >= rank: + raise ValueError(f"{op_name} axis {original_axis} is out of range for rank {rank}.") + normalized_axes.append(axis) + if len(normalized_axes) != len(set(normalized_axes)): + raise ValueError(f"{op_name} axes must be unique.") + return normalized_axes + + +def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr: + """Convert a tensor-like expression to an int64 tensor expression.""" + + if isinstance(expr, relax.ShapeExpr): + return bb.normalize(relax.op.shape_to_tensor(expr)) + if isinstance(expr, relax.PrimValue): + return bb.normalize(relax.op.full((1,), expr, dtype="int64")) + if isinstance(expr, relax.Constant): + if expr.struct_info.dtype == "int64": + return expr + return bb.normalize(relax.op.astype(expr, "int64")) + if isinstance(expr.struct_info, relax.TensorStructInfo) and expr.struct_info.dtype != "int64": + return bb.normalize(relax.op.astype(expr, "int64")) + return expr + + +def _tensor_to_shape_expr( + bb: relax.BlockBuilder, shape_tensor: relax.Expr, shape_ndim: int, prefix: str +) -> relax.ShapeExpr: + """Convert a statically sized int64 tensor into a ShapeExpr.""" + + shape_tensor = bb.match_cast(shape_tensor, relax.TensorStructInfo([shape_ndim], "int64")) + shape_dataflow_var = bb.emit(relax.op.tensor_to_shape(shape_tensor)) + shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in range(shape_ndim)] + bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) + return relax.ShapeExpr(shape_vars) + + +def _build_unsqueezed_shape_tensor( + bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int +) -> relax.Expr: + """Build the output shape tensor for Unsqueeze with runtime axes.""" + + axes = _as_int64_tensor(bb, axes) + axes_len = _get_known_tensor_length(axes) + if axes_len is None: + raise ValueError("Unsqueeze requires a statically known axes length.") + + output_ndim = data_ndim + axes_len + axes = bb.normalize( + relax.op.where( + relax.op.less(axes, relax.const(0, "int64")), + relax.op.add(axes, relax.const(output_ndim, "int64")), + axes, + ) + ) + positions = relax.op.arange(output_ndim, dtype="int64") + positions = bb.normalize(relax.op.expand_dims(positions, axis=1)) + axes = bb.normalize(relax.op.expand_dims(axes, axis=0)) + insert_mask = bb.normalize( + relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1) + ) + keep_mask = bb.normalize(relax.op.subtract(relax.const(1, "int64"), insert_mask)) + input_indices = bb.normalize( + relax.op.subtract(relax.op.cumsum(keep_mask, axis=0), relax.const(1, "int64")) + ) + safe_indices = bb.normalize( + relax.op.where( + relax.op.less(input_indices, relax.const(0, "int64")), + relax.const(0, "int64"), + input_indices, + ) + ) + kept_dims = bb.normalize(relax.op.take(data_shape_tensor, safe_indices, axis=0)) + return bb.normalize( + relax.op.where( + relax.op.greater(insert_mask, relax.const(0, "int64")), + relax.const(1, "int64"), + kept_dims, + ) + ) + + +def _build_squeezed_shape_tensor( + bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int +) -> relax.Expr: + """Build the output shape tensor for Squeeze with runtime axes.""" + + axes = _as_int64_tensor(bb, axes) + axes = bb.normalize( + relax.op.where( + relax.op.less(axes, relax.const(0, "int64")), + relax.op.add(axes, relax.const(data_ndim, "int64")), + axes, + ) + ) + positions = relax.op.arange(data_ndim, dtype="int64") + positions = bb.normalize(relax.op.expand_dims(positions, axis=1)) + axes = bb.normalize(relax.op.expand_dims(axes, axis=0)) + remove_mask = bb.normalize( + relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1) + ) + keep_mask = bb.normalize(relax.op.equal(remove_mask, relax.const(0, "int64"))) + keep_indices = bb.normalize(relax.op.nonzero(keep_mask)) + num_keep_dims = tirx.Var("squeeze_num_keep_dims", "int64") + keep_indices = bb.match_cast(keep_indices, relax.TensorStructInfo([1, num_keep_dims], "int64")) + keep_indices = bb.normalize(relax.op.reshape(keep_indices, [-1])) + return bb.normalize(relax.op.take(data_shape_tensor, keep_indices, axis=0)) + + class Slice(OnnxOpConverter): - """Converts an onnx Splice node into an equivalent Relax expression.""" + """Converts an onnx Slice node into an equivalent Relax expression.""" @classmethod def _impl_v13(cls, bb, inputs, attr, params): - # TODO (jwfromm) currently only supports constant parameters. data = inputs[0] starts = get_constant(inputs[1], params) ends = get_constant(inputs[2], params) axes = get_constant(inputs[3], params) steps = get_constant(inputs[4], params) - if not all( - [ - ( - isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue) - or param is None + all_constant_params = all( + isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue) or param is None + for param in [starts, ends, axes, steps] + ) + if all_constant_params: + starts = get_prim_expr_list(starts) + ends = get_prim_expr_list(ends) + if len(starts) != len(ends): + raise ValueError( + f"Slice expects starts and ends to have the same length, but got " + f"{len(starts)} and {len(ends)}." ) - for param in [starts, ends, axes, steps] - ] - ): - raise ValueError("Only constant Slice parameters are currently supported.") - # Convert parameters to constant lists. - starts = get_prim_expr_list(starts) - ends = get_prim_expr_list(ends) - if axes is not None: - axes = get_prim_expr_list(axes) - else: - axes = list(range(len(starts))) - # Convert negative axis to positive if needed. - for i, axis in enumerate(axes): - if axis < 0: - axes[i] = axis + len(data.struct_info.shape) - if steps is not None: - steps = get_prim_expr_list(steps) - else: - steps = [1] * len(axes) - # If input is a shape tensor, we can directly extract it. - if isinstance(data, relax.ShapeExpr): - shape_data = list(data) - # Starts, ends, and steps must be 1-d for shape operation. - assert all(len(i) == 1 for i in [starts, ends, steps]) - sliced_values = shape_data[starts[0] : ends[0] : steps[0]] - - if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]): - return relax.const([x.value for x in sliced_values], "int64") + if axes is not None: + axes = get_prim_expr_list(axes) + if len(axes) != len(starts): + raise ValueError( + f"Slice expects axes and starts to have the same length, but got " + f"{len(axes)} and {len(starts)}." + ) else: + axes = list(range(len(starts))) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError("Slice requires a statically known input rank.") + axes = _normalize_constant_axes(list(axes), data_ndim, "Slice") + if steps is not None: + steps = get_prim_expr_list(steps) + if len(steps) != len(starts): + raise ValueError( + f"Slice expects steps and starts to have the same length, but got " + f"{len(steps)} and {len(starts)}." + ) + else: + steps = [1] * len(axes) + if any( + (isinstance(step, int) and step == 0) + or (isinstance(step, tirx.IntImm) and int(step) == 0) + for step in steps + ): + raise ValueError("Slice step values must be non-zero.") + if isinstance(data, relax.ShapeExpr): + shape_data = list(data) + assert all(len(i) == 1 for i in [starts, ends, steps]) + sliced_values = shape_data[starts[0] : ends[0] : steps[0]] + + if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]): + return relax.const([x.value for x in sliced_values], "int64") return relax.ShapeExpr(sliced_values) - # If all `starts`, `ends`, and `steps` are constant, use strict mode - # Otherwise, we assume the slice is inbound. - assume_inbound = not all( - [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] - ) + assume_inbound = not all( + [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] + ) + starts = get_prim_value_list(starts) + ends = get_prim_value_list(ends) + steps = get_prim_value_list(steps) - # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr - starts = get_prim_value_list(starts) - ends = get_prim_value_list(ends) - steps = get_prim_value_list(steps) + return relax.op.strided_slice( + data, axes, starts, ends, steps, assume_inbound=assume_inbound + ) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError( + "Slice with dynamic parameters requires a statically known input rank." + ) + + if isinstance(data, relax.ShapeExpr): + raise ValueError("Slice with dynamic parameters does not support ShapeExpr input.") + data_expr = data + + starts_tensor = _as_int64_tensor(bb, starts) + ends_tensor = _as_int64_tensor(bb, ends) + axes_len = _get_known_tensor_length(starts_tensor) + if axes_len is None: + raise ValueError("Slice requires a statically known starts length.") + ends_len = _get_known_tensor_length(ends_tensor) + if ends_len is None: + raise ValueError("Slice requires a statically known ends length.") + if ends_len != axes_len: + raise ValueError( + f"Slice expects starts and ends to have the same length, but got " + f"{axes_len} and {ends_len}." + ) - return relax.op.strided_slice( - data, axes, starts, ends, steps, assume_inbound=assume_inbound + if axes is None: + axes_tensor = relax.op.arange(axes_len, dtype="int64") + else: + axes_tensor = _as_int64_tensor(bb, axes) + axes_tensor_len = _get_known_tensor_length(axes_tensor) + if axes_tensor_len is None: + raise ValueError("Slice requires a statically known axes length.") + if axes_tensor_len != axes_len: + raise ValueError( + f"Slice expects axes and starts to have the same length, but got " + f"{axes_tensor_len} and {axes_len}." + ) + if steps is None: + steps_tensor = relax.const(_np.ones((axes_len,), dtype="int64"), "int64") + else: + steps_tensor = _as_int64_tensor(bb, steps) + steps_len = _get_known_tensor_length(steps_tensor) + if steps_len is None: + raise ValueError("Slice requires a statically known steps length.") + if steps_len != axes_len: + raise ValueError( + f"Slice expects steps and starts to have the same length, but got " + f"{steps_len} and {axes_len}." + ) + if isinstance(steps_tensor, relax.Constant) and _np.any(steps_tensor.data.numpy() == 0): + raise ValueError("Slice step values must be non-zero.") + + axes_tensor = bb.normalize( + relax.op.where( + relax.op.less(axes_tensor, relax.const(0, "int64")), + relax.op.add(axes_tensor, relax.const(data_ndim, "int64")), + axes_tensor, + ) + ) + + data_shape = bb.normalize(relax.op.shape_of(data_expr)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + full_starts = relax.const(_np.zeros((data_ndim,), dtype="int64"), "int64") + full_steps = relax.const(_np.ones((data_ndim,), dtype="int64"), "int64") + full_starts = bb.normalize( + relax.op.scatter_elements(full_starts, axes_tensor, starts_tensor) + ) + full_ends = bb.normalize( + relax.op.scatter_elements(data_shape_tensor, axes_tensor, ends_tensor) ) + full_steps = bb.normalize(relax.op.scatter_elements(full_steps, axes_tensor, steps_tensor)) + return relax.op.dynamic_strided_slice(data_expr, full_starts, full_ends, full_steps) class Pad(OnnxOpConverter): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index c848ef91d6a1..8b1292617896 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -233,14 +233,14 @@ def run_in_tvm( def collect_relax_call_ops(func: relax.Function) -> list[str]: - op_names = [] + op_names: list[str] = [] - def fvisit(expr): + def fvisit(expr: relax.Expr) -> None: if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op): op_names.append(expr.op.name) relax.analysis.post_order_visit(func.body, fvisit) - return op_names + return list(op_names) def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]: @@ -1058,6 +1058,98 @@ def test_unsqueeze(): check_correctness(model) +def test_unsqueeze_scalar_input(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_scalar_input", + inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [])], + initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1])], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_scalar_input_test") + inputs = {"a": np.array(3.0, dtype="float32")} + check_correctness(model, inputs, opset=13) + + +def test_unsqueeze_dynamic_axes(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_test") + inputs = { + "a": rg.standard_normal(size=[32, 32]).astype("float32"), + "axes": np.array([-1, 0], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_unsqueeze_dynamic_axes_ir(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes_ir", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.tensor_to_shape" in call_ops + assert "relax.reshape" in call_ops + + +def test_unsqueeze_dynamic_axes_rank_validation(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes_rank_validation", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_rank_validation_test") + with pytest.raises(ValueError, match="Expected a 1-D tensor"): + from_onnx(model, opset=13, keep_params_in_input=True) + + +def test_unsqueeze_duplicate_axes_validation(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_duplicate_axes_validation", + inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])], + initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 0])], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1, 32, 32])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_duplicate_axes_validation_test") + with pytest.raises(ValueError, match="axes must be unique"): + from_onnx(model, opset=13) + + def test_unsqueeze_v1(): # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1 unsqueeze_node = helper.make_node("Unsqueeze", ["a"], ["b"], axes=[0, 2, 3]) @@ -1545,6 +1637,70 @@ def test_dynamic_squeeze(axis, A, B): check_correctness(model, inputs, opset=13) +def test_squeeze_dynamic_axes(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_test") + inputs = { + "x": rg.standard_normal(size=shape).astype("float32"), + "axes": np.array([-4, 2], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_squeeze_dynamic_axes_ir(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_ir", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.tensor_to_shape" in call_ops + assert "relax.reshape" in call_ops + assert "relax.squeeze" not in call_ops + + +def test_squeeze_dynamic_axes_rank_validation(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_rank_validation", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_rank_validation_test") + with pytest.raises(ValueError, match="Expected a 1-D tensor"): + from_onnx(model, opset=13, keep_params_in_input=True) + + @pytest.mark.parametrize("axis", [[0]]) @pytest.mark.parametrize("A", [8, 16, 32]) def test_dynamic_shape_squeeze(axis, A): @@ -2448,6 +2604,121 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None): # ) +def test_slice_dynamic_inputs(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [2]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_test") + inputs = { + "x": rg.standard_normal(size=[20, 10, 5]).astype("float32"), + "starts": np.array([0, 0], dtype="int64"), + "ends": np.array([3, 10], dtype="int64"), + "axes": np.array([0, 1], dtype="int64"), + "steps": np.array([1, 1], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_slice_dynamic_inputs_ir(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_ir", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [2]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.dynamic_strided_slice" in call_ops + assert "relax.strided_slice" not in call_ops + + +def test_slice_dynamic_inputs_length_validation(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_length_validation", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [1]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_length_validation_test") + with pytest.raises(ValueError, match="starts and ends to have the same length"): + from_onnx(model, opset=13, keep_params_in_input=True) + + +def test_slice_dynamic_shape_expr_input_validation(): + shape_node = helper.make_node("Shape", ["x"], ["y"]) + slice_node = helper.make_node("Slice", ["y", "starts", "ends", "axes", "steps"], ["z"]) + + graph = helper.make_graph( + [shape_node, slice_node], + "slice_dynamic_shape_expr_input_validation", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [1]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [1]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [1]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [1]), + ], + outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [1])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_shape_expr_input_validation_test") + with pytest.raises(ValueError, match="does not support ShapeExpr input"): + from_onnx(model, opset=13, keep_params_in_input=True) + + +def test_slice_zero_step_validation(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_zero_step_validation", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5])], + initializer=[ + helper.make_tensor("starts", TensorProto.INT64, [2], vals=[0, 0]), + helper.make_tensor("ends", TensorProto.INT64, [2], vals=[3, 10]), + helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1]), + helper.make_tensor("steps", TensorProto.INT64, [2], vals=[1, 0]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_zero_step_validation_test") + with pytest.raises(ValueError, match="step values must be non-zero"): + from_onnx(model, opset=13) + + def test_slice_dynamic_shape(): def verify_slice( data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None