[Relax][ONNX] Fix shape/dynamic restrictions for Unsqueeze/Squeeze and Slice #18954

Closed
Aharrypotter wants to merge 6 commits into apache:main from Aharrypotter:relax-onnx-dynamic-unsqueeze-squeeze-slice

@Aharrypotter
Contributor

Description

This PR fixes shape and dynamic-parameter restrictions for the Unsqueeze, Squeeze, and Slice operators in the Relax ONNX frontend, as part of tracking issue #18945.

Previously, these operators had limited support for symbolic axes or dynamic parameters. This PR enables more robust dynamic shape handling and aligns the implementation with TVM's Relax operator standards.

Relates to #18945.

Changes

  • Unsqueeze:
    • Enabled support for dynamic axes.
    • Added validation to reject duplicate axes (as per ONNX spec).
    • Tightened validation for symbolic scalar inputs.
  • Squeeze:
    • Improved internal shape tensor building to better handle dynamic cases.
  • Slice:
    • Full support for dynamic/symbolic starts, ends, and steps.
    • Ensured relax.dynamic_strided_slice is correctly emitted when parameters are not constant.
    • Added explicit validation to reject zero-step values (prevents infinite loops/invalid IR).
  • Cleanup: Fixed a typo in the Slice converter docstring ("Splice" -> "Slice").
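
As an illustration of the duplicate-axes rule listed for Unsqueeze above, here is a minimal pure-Python sketch. The helper name `normalize_unsqueeze_axes` is hypothetical and is not taken from the PR; it assumes the axes are already known as plain integers and that negative axes count from the end of the output rank, as the ONNX spec describes.

```python
def normalize_unsqueeze_axes(axes, input_rank):
    """Normalize Unsqueeze axes and reject out-of-range or duplicate entries.

    Hypothetical helper for illustration only; not the converter's actual code.
    """
    # Unsqueeze inserts one new dimension per axis, so axes index the output.
    output_rank = input_rank + len(axes)
    normalized = []
    for axis in axes:
        if not -output_rank <= axis < output_rank:
            raise ValueError(
                f"Unsqueeze axis {axis} is out of range for output rank {output_rank}."
            )
        # Map negative axes onto their non-negative equivalents.
        normalized.append(axis + output_rank if axis < 0 else axis)
    # Two axes that normalize to the same position are duplicates.
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"Unsqueeze axes must be unique, got {list(axes)}.")
    return sorted(normalized)
```

Note that duplicates are checked after normalization, so a pair such as `1` and `-3` on a rank-2 input is rejected even though the raw values differ.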

Testing & Validation

  • Added structural IR tests to verify that relax.dynamic_strided_slice is used for dynamic paths.
  • Added negative tests for edge cases: duplicate axes in Unsqueeze and zero-step in Slice.
  • Verified all new and existing ONNX frontend tests for these operators pass.

Test Command:

python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "unsqueeze or squeeze or slice"

Result: All tests passed.

…ctural tests

- fix Slice converter docstring typo ("Splice" -> "Slice") for consistency
- add explicit validation to reject zero step values in Slice
  for both constant and dynamic-parameter paths
- add Unsqueeze negative test to reject duplicate axes
- strengthen structural IR test for dynamic Slice to assert
  relax.dynamic_strided_slice is used and relax.strided_slice is not
- add Slice negative test for zero-step input

Validation:
- python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py
- python -m pre_commit run --files python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py
- python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "unsqueeze_dynamic_axes or unsqueeze_duplicate_axes_validation or slice_dynamic_inputs_ir or slice_dynamic_inputs_length_validation or slice_zero_step_validation" -v

Result:
- 7 passed

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enhances the ONNX frontend for TVM Relax by adding support for MatMulInteger16 and ONNX Optional operators, and updating Unsqueeze, Squeeze, and Slice to handle dynamic axes and parameters. New utility functions for tensor rank and length analysis were introduced to facilitate these dynamic operations. Feedback suggests simplifying the constant Unsqueeze implementation to more robustly handle scalar inputs and adding runtime assertions to validate that dynamic Slice steps are non-zero.

Comment on lines 753 to 758
expanded = data.data.numpy()
if len(expanded.shape) == 0:
# Special case implying input is a scalar, wrap it as a list.
if 0 in axes:
axes.remove(0)
expanded = [expanded]
for axis in axes:
constant_axes = [axis - 1 for axis in constant_axes if axis != 0]
for axis in constant_axes:
expanded = _np.expand_dims(expanded, axis=axis)

high

The logic for handling constant Unsqueeze is overly complex and appears to have a bug when dealing with scalar inputs. The special handling for scalars can lead to incorrect behavior due to how numpy arrays and lists are treated.

A simpler and more robust approach is to calculate the target shape and then use reshape. This works correctly for both scalar and tensor inputs and is easier to understand.

            expanded = data.data.numpy()
            output_rank = expanded.ndim + len(constant_axes)

            # Create a target shape with 1s at the unsqueezed axes
            new_shape = [1] * output_rank
            input_dims_iter = iter(expanded.shape)
            for i in range(output_rank):
                if i not in constant_axes:
                    new_shape[i] = next(input_dims_iter)

            # Reshape to the new shape. This is equivalent to a series of expand_dims.
            expanded = expanded.reshape(new_shape)
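
For reference, the reshape idea can be tried outside TVM with plain NumPy. This is only a sketch: the function name `unsqueeze_by_reshape` is hypothetical, and unlike the snippet above it also normalizes negative axes, on the assumption that every axis is already within the valid range for the output rank.

```python
import numpy as np


def unsqueeze_by_reshape(array, axes):
    """Insert size-1 dimensions at `axes` by computing the target shape once.

    Hypothetical standalone helper; assumes axes are in [-output_rank, output_rank).
    """
    array = np.asarray(array)
    output_rank = array.ndim + len(axes)
    # Python's modulo maps in-range negative axes to non-negative positions.
    axes_set = {axis % output_rank for axis in axes}
    if len(axes_set) != len(axes):
        raise ValueError(f"Unsqueeze axes must be unique, got {list(axes)}.")
    # Fill non-unsqueezed positions with the input dimensions, in order.
    input_dims = iter(array.shape)
    new_shape = [1 if i in axes_set else next(input_dims) for i in range(output_rank)]
    return array.reshape(new_shape)
```

Because the target shape is built directly, a scalar input needs no special case: it has an empty shape, so every output position is one of the unsqueezed axes.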

Comment on lines +2210 to +2211
if isinstance(steps_tensor, relax.Constant) and _np.any(steps_tensor.data.numpy() == 0):
    raise ValueError("Slice step values must be non-zero.")

high

The check for zero-step values is only performed for constant steps. According to the ONNX specification, a step value of 0 is invalid. If steps is a dynamic tensor, a zero value could still be passed at runtime, potentially leading to invalid IR or a runtime error. A runtime assertion should be added to handle this case for dynamic steps.

            if isinstance(steps_tensor, relax.Constant):
                if _np.any(steps_tensor.data.numpy() == 0):
                    raise ValueError("Slice step values must be non-zero.")
            else:
                # Add a runtime check for dynamic steps
                is_zero = relax.op.equal(steps_tensor, relax.const(0, "int64"))
                bb.emit(relax.op.Assert(relax.op.all(relax.op.logical_not(is_zero)), "Slice step values must be non-zero."))
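
The intent of the check can be illustrated in eager NumPy, without any Relax calls. The helper `check_steps_nonzero` below is a hypothetical stand-in for whatever validation the converter ends up using; it only shows the condition being tested and why a zero step has no sensible meaning.

```python
import numpy as np


def check_steps_nonzero(steps):
    """Raise if any Slice step is zero; otherwise return the steps as int64.

    Hypothetical helper for illustration; the runtime path in Relax would need
    an equivalent check expressed in IR instead.
    """
    steps = np.asarray(steps, dtype=np.int64)
    if np.any(steps == 0):
        raise ValueError("Slice step values must be non-zero.")
    return steps


def numpy_rejects_zero_step():
    """Show that NumPy itself refuses a zero slice step."""
    try:
        np.arange(5)[::0]
    except ValueError:
        return True
    return False
```

NumPy and Python slicing both raise on a zero step, which matches the ONNX requirement the review comment refers to.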
