[Relax][ONNX] Fix shape/dynamic restrictions for Unsqueeze/Squeeze and Slice #18954
Aharrypotter wants to merge 6 commits into apache:main
…ctural tests
- fix Slice converter docstring typo ("Splice" -> "Slice") for consistency
- add explicit validation to reject zero step values in Slice
for both constant and dynamic-parameter paths
- add Unsqueeze negative test to reject duplicate axes
- strengthen structural IR test for dynamic Slice to assert
relax.dynamic_strided_slice is used and relax.strided_slice is not
- add Slice negative test for zero-step input
Validation:
- python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py
- python -m pre_commit run --files python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py
- python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "unsqueeze_dynamic_axes or unsqueeze_duplicate_axes_validation or slice_dynamic_inputs_ir or slice_dynamic_inputs_length_validation or slice_zero_step_validation" -v
Result:
- 7 passed
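The duplicate-axes rejection listed above can be illustrated outside TVM. The following is a minimal sketch in plain Python, not the converter's actual code: the helper name `normalize_unsqueeze_axes` is hypothetical, and it assumes the ONNX rule that Unsqueeze axes are interpreted against the output rank (input rank plus the number of axes).

```python
def normalize_unsqueeze_axes(axes, input_rank):
    """Map possibly negative Unsqueeze axes to non-negative ones and reject duplicates.

    Hypothetical helper for illustration; axes are resolved against the
    output rank, which is input_rank + len(axes).
    """
    output_rank = input_rank + len(axes)
    normalized = []
    for axis in axes:
        if not -output_rank <= axis < output_rank:
            raise ValueError(
                f"Unsqueeze axis {axis} is out of range for output rank {output_rank}."
            )
        normalized.append(axis + output_rank if axis < 0 else axis)
    # Two axes that resolve to the same position (e.g. 1 and -3 at rank 4) are invalid.
    if len(set(normalized)) != len(normalized):
        raise ValueError("Unsqueeze axes must be unique.")
    return sorted(normalized)
```

Normalizing before the uniqueness check matters: `[1, -3]` contains no repeated literal, yet both entries name the same output position when the input rank is 2.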
Code Review
This pull request enhances the ONNX frontend for TVM Relax by adding support for MatMulInteger16 and ONNX Optional operators, and updating Unsqueeze, Squeeze, and Slice to handle dynamic axes and parameters. New utility functions for tensor rank and length analysis were introduced to facilitate these dynamic operations. Feedback suggests simplifying the constant Unsqueeze implementation to more robustly handle scalar inputs and adding runtime assertions to validate that dynamic Slice steps are non-zero.
expanded = data.data.numpy()
if len(expanded.shape) == 0:
# Special case implying input is a scalar, wrap it as a list.
if 0 in axes:
axes.remove(0)
expanded = [expanded]
for axis in axes:
constant_axes = [axis - 1 for axis in constant_axes if axis != 0]
for axis in constant_axes:
expanded = _np.expand_dims(expanded, axis=axis)
The logic for handling constant Unsqueeze is overly complex and appears to have a bug when dealing with scalar inputs. The special handling for scalars can lead to incorrect behavior due to how numpy arrays and lists are treated.
A simpler and more robust approach is to calculate the target shape and then use reshape. This works correctly for both scalar and tensor inputs and is easier to understand.
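The reshape idea can be checked in isolation with NumPy. This is a sketch under stated assumptions rather than the converter's code: `unsqueeze_by_reshape` is a hypothetical name, and the axes are assumed to already lie in the valid range for the output rank, so negative values are resolved with a modulo.

```python
import numpy as np


def unsqueeze_by_reshape(array, axes):
    """Insert size-1 dimensions at `axes` by building the target shape once.

    Hypothetical helper; assumes every axis is within [-output_rank, output_rank).
    """
    array = np.asarray(array)
    output_rank = array.ndim + len(axes)
    # Resolve negative axes against the output rank.
    positions = {axis % output_rank for axis in axes}
    if len(positions) != len(axes):
        raise ValueError("Unsqueeze axes must be unique.")
    dims = iter(array.shape)
    # Size 1 at each unsqueezed position, original dimensions everywhere else.
    new_shape = [1 if i in positions else next(dims) for i in range(output_rank)]
    return array.reshape(new_shape)


# A 0-d input needs no special case: its shape is empty, so every output dim is 1.
scalar_result = unsqueeze_by_reshape(np.array(3.0), [0])
matrix_result = unsqueeze_by_reshape(np.arange(6).reshape(2, 3), [0, -1])
```

Because the target shape is computed in one pass, the result does not depend on the order in which the axes are listed, unlike a loop of single-axis `expand_dims` calls.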
expanded = data.data.numpy()
output_rank = expanded.ndim + len(constant_axes)
# Create a target shape with 1s at the unsqueezed axes
new_shape = [1] * output_rank
input_dims_iter = iter(expanded.shape)
for i in range(output_rank):
    if i not in constant_axes:
        new_shape[i] = next(input_dims_iter)
# Reshape to the new shape. This is equivalent to a series of expand_dims.
expanded = expanded.reshape(new_shape)
if isinstance(steps_tensor, relax.Constant) and _np.any(steps_tensor.data.numpy() == 0):
    raise ValueError("Slice step values must be non-zero.")
The check for zero-step values is only performed for constant steps. According to the ONNX specification, a step value of 0 is invalid. If steps is a dynamic tensor, a zero value could still be passed at runtime, potentially leading to invalid IR or a runtime error. A runtime assertion should be added to handle this case for dynamic steps.
if isinstance(steps_tensor, relax.Constant):
    if _np.any(steps_tensor.data.numpy() == 0):
        raise ValueError("Slice step values must be non-zero.")
else:
    # Add a runtime check for dynamic steps
    is_zero = relax.op.equal(steps_tensor, relax.const(0, "int64"))
    bb.emit(relax.op.Assert(relax.op.all(relax.op.logical_not(is_zero)), "Slice step values must be non-zero."))
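The distinction between the two paths is only about when the step values become known. As a rough illustration in NumPy, with no TVM involved, the same predicate can run at conversion time for constant steps or be deferred until concrete values arrive; `validate_steps` is a hypothetical name and this sketch does not reproduce how Relax would emit such a check.

```python
import numpy as np


def validate_steps(steps):
    """Raise if any Slice step is zero; return the steps as an int64 array.

    Hypothetical helper: call it during conversion when steps are constant,
    or at run time once a dynamic steps tensor has concrete values.
    """
    steps = np.asarray(steps, dtype="int64")
    if np.any(steps == 0):
        raise ValueError("Slice step values must be non-zero.")
    return steps
```

Python's own slicing enforces the same rule (`[1, 2, 3][::0]` raises `ValueError`), which is one reason a zero step has no sensible meaning in a strided slice.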
Description
This PR fixes shape and dynamic-parameter restrictions for the Unsqueeze, Squeeze, and Slice operators in the Relax ONNX frontend, as part of the tracking issue #18945.
Previously, these operators had limited support for symbolic axes or dynamic parameters. This PR enables more robust dynamic shape handling and aligns the implementation with TVM's Relax operator standards.
Relates to #18945.
Changes
- Slice: dynamic starts, ends, and steps.
- relax.dynamic_strided_slice is correctly emitted when parameters are not constant.
- Fixed the Slice converter docstring typo ("Splice" -> "Slice").
Testing & Validation
- Structural IR tests check that relax.dynamic_strided_slice is used for dynamic paths.
- Negative tests cover duplicate axes in Unsqueeze and zero-step in Slice.
Test Command:
python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "unsqueeze or squeeze or slice"
Result: All tests passed.
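When checking converter output for Slice, a small NumPy reference can serve as the expected value. The sketch below is an assumption-laden illustration, not part of this PR: `onnx_slice_reference` is a hypothetical name, it relies on Python slice semantics for clamping and negative indices, and it does not check for repeated axes.

```python
import numpy as np


def onnx_slice_reference(data, starts, ends, axes=None, steps=None):
    """Reference for ONNX-style Slice built on NumPy basic slicing (hypothetical helper)."""
    data = np.asarray(data)
    # When omitted, axes default to the leading dimensions and steps to 1.
    axes = list(range(len(starts))) if axes is None else list(axes)
    steps = [1] * len(starts) if steps is None else list(steps)
    if not (len(starts) == len(ends) == len(axes) == len(steps)):
        raise ValueError("starts, ends, axes, and steps must have the same length.")
    index = [slice(None)] * data.ndim
    for start, end, axis, step in zip(starts, ends, axes, steps):
        if step == 0:
            raise ValueError("Slice step values must be non-zero.")
        # Out-of-range starts and ends are clamped by Python slicing itself.
        index[axis] = slice(start, end, step)
    return data[tuple(index)]


x = np.arange(20).reshape(4, 5)
forward = onnx_slice_reference(x, [1, 0], [3, 5], axes=[0, 1], steps=[1, 2])
backward = onnx_slice_reference(x, [4], [-100], axes=[1], steps=[-1])
```

Here `forward` selects rows 1 and 2 with every second column, and `backward` reverses the columns by using a negative step with an end value far below the valid range.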