diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 4e3351bb90c8..69ce458e7e35 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,66 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode);
 };  // struct ROIAlignAttrs
 
+/*! \brief Attributes used in GetValidCounts operator (score filter that compacts valid boxes) */
+struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
+  double score_threshold;
+  int id_index;
+  int score_index;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<GetValidCountsAttrs>()
+        .def_ro("score_threshold", &GetValidCountsAttrs::score_threshold,
+                "Lower limit of score for valid bounding boxes.")
+        .def_ro("id_index", &GetValidCountsAttrs::id_index,
+                "Index of the class categories, -1 to disable.")
+        .def_ro("score_index", &GetValidCountsAttrs::score_index,
+                "Index of the scores/confidence of boxes.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GetValidCountsAttrs", GetValidCountsAttrs,
+                                    BaseAttrsNode);
+};  // struct GetValidCountsAttrs
+
+/*! \brief Attributes used in NonMaximumSuppression operator (classic, per-batch greedy NMS) */
+struct NonMaximumSuppressionAttrs
+    : public AttrsNodeReflAdapter<NonMaximumSuppressionAttrs> {
+  int max_output_size;
+  double iou_threshold;
+  bool force_suppress;
+  int top_k;
+  int coord_start;
+  int score_index;
+  int id_index;
+  bool return_indices;
+  bool invalid_to_bottom;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<NonMaximumSuppressionAttrs>()
+        .def_ro("max_output_size", &NonMaximumSuppressionAttrs::max_output_size,
+                "Max number of output valid boxes, -1 for no limit.")
+        .def_ro("iou_threshold", &NonMaximumSuppressionAttrs::iou_threshold,
+                "Non-maximum suppression IoU threshold.")
+        .def_ro("force_suppress", &NonMaximumSuppressionAttrs::force_suppress,
+                "Whether to suppress all detections regardless of class_id.")
+        .def_ro("top_k", &NonMaximumSuppressionAttrs::top_k,
+                "Keep maximum top k detections before nms, -1 for no limit.")
+        .def_ro("coord_start", &NonMaximumSuppressionAttrs::coord_start,
+                "Start index of the consecutive 4 coordinates.")
+        .def_ro("score_index", &NonMaximumSuppressionAttrs::score_index,
+                "Index of the scores/confidence of boxes.")
+        .def_ro("id_index", &NonMaximumSuppressionAttrs::id_index,
+                "Index of the class categories, -1 to disable.")
+        .def_ro("return_indices", &NonMaximumSuppressionAttrs::return_indices,
+                "Whether to return box indices in input data.")
+        .def_ro("invalid_to_bottom", &NonMaximumSuppressionAttrs::invalid_to_bottom,
+                "Whether to move all valid bounding boxes to the top.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
+                                    NonMaximumSuppressionAttrs, BaseAttrsNode);
+};  // struct NonMaximumSuppressionAttrs
+
+
 /*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box decode).
*/ struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter { bool clip; diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index ee1a2c24206e..0b8dc4e7de59 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -157,7 +157,13 @@ tanh, trunc, ) -from .vision import all_class_non_max_suppression, multibox_transform_loc, roi_align +from .vision import ( + all_class_non_max_suppression, + get_valid_counts, + multibox_transform_loc, + non_max_suppression, + roi_align, +) def _register_op_make(): diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index e8c91f04b459..d85c439d3ae2 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -246,6 +246,16 @@ class AllClassNonMaximumSuppressionAttrs(Attrs): """Attributes for vision.all_class_non_max_suppression""" +@tvm_ffi.register_object("relax.attrs.GetValidCountsAttrs") +class GetValidCountsAttrs(Attrs): + """Attributes for vision.get_valid_counts""" + + +@tvm_ffi.register_object("relax.attrs.NonMaximumSuppressionAttrs") +class NonMaximumSuppressionAttrs(Attrs): + """Attributes for vision.non_max_suppression""" + + @tvm_ffi.register_object("relax.attrs.ROIAlignAttrs") class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" diff --git a/python/tvm/relax/op/vision/nms.py b/python/tvm/relax/op/vision/nms.py index 616c74ddf604..4eb3eb7f7a78 100644 --- a/python/tvm/relax/op/vision/nms.py +++ b/python/tvm/relax/op/vision/nms.py @@ -14,9 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Non-maximum suppression operator""" +"""Non-maximum suppression operators.""" -# from tvm import relax # Unused import from . 
import _ffi_api @@ -72,3 +71,114 @@ def all_class_non_max_suppression( return _ffi_api.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ) + + +def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): + """Get valid count of bounding boxes given a score threshold. + Also moves valid boxes to the top of input data. + + Parameters + ---------- + data : relax.Expr + 3-D tensor with shape [batch_size, num_anchors, elem_length]. + + score_threshold : float, optional + Lower limit of score for valid bounding boxes. + + id_index : int, optional + Index of the class categories. Set to ``-1`` to disable the class-id check. + + score_index : int, optional + Index of the scores/confidence of boxes. + + Returns + ------- + out : relax.Expr + A tuple ``(valid_count, out_tensor, out_indices)`` where ``valid_count`` + has shape ``[batch_size]``, ``out_tensor`` has shape + ``[batch_size, num_anchors, elem_length]``, and ``out_indices`` has shape + ``[batch_size, num_anchors]``. + """ + return _ffi_api.get_valid_counts(data, score_threshold, id_index, score_index) + + +def non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, +): + """Non-maximum suppression operator for object detection. + + Parameters + ---------- + data : relax.Expr + 3-D tensor with shape [batch_size, num_anchors, elem_length]. + + valid_count : relax.Expr + 1-D tensor for valid number of boxes. + + indices : relax.Expr + 2-D tensor with shape [batch_size, num_anchors]. + + max_output_size : int, optional + Max number of output valid boxes, -1 for no limit. + + iou_threshold : float, optional + Non-maximum suppression IoU threshold. + + force_suppress : bool, optional + Whether to suppress all detections regardless of class_id. 
When + ``id_index`` is ``-1``, all valid boxes are treated as belonging to the + same class, so this flag has the same effect as ``True``. + + top_k : int, optional + Keep maximum top k detections before nms, -1 for no limit. + + coord_start : int, optional + Start index of the consecutive 4 coordinates. + + score_index : int, optional + Index of the scores/confidence of boxes. + + id_index : int, optional + Index of the class categories. Set to ``-1`` to suppress boxes across + all classes. + + return_indices : bool, optional + Whether to return box indices in input data. + + invalid_to_bottom : bool, optional + Whether to move valid bounding boxes to the top of the returned tensor. + This option only affects the ``return_indices=False`` path. + + Returns + ------- + out : relax.Expr + If ``return_indices`` is ``True``, returns + ``(box_indices, valid_box_count)`` with shapes + ``[batch_size, num_anchors]`` and ``[batch_size, 1]``. + Otherwise returns the modified data tensor. + """ + return _ffi_api.non_max_suppression( + data, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + score_index, + id_index, + return_indices, + invalid_to_bottom, + ) diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index 28367a67a361..ea0458bfcef0 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -120,6 +120,36 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.vision.get_valid_counts") +def _get_valid_counts(block_builder: BlockBuilder, call: Call) -> Expr: + return block_builder.call_te( + topi.vision.get_valid_counts, + call.args[0], + score_threshold=call.attrs.score_threshold, + id_index=call.attrs.id_index, + score_index=call.attrs.score_index, + ) + + +@register_legalize("relax.vision.non_max_suppression") +def _non_max_suppression(block_builder: 
BlockBuilder, call: Call) -> Expr: + return block_builder.call_te( + topi.vision.non_max_suppression, + call.args[0], + call.args[1], + call.args[2], + max_output_size=call.attrs.max_output_size, + iou_threshold=call.attrs.iou_threshold, + force_suppress=call.attrs.force_suppress, + top_k=call.attrs.top_k, + coord_start=call.attrs.coord_start, + score_index=call.attrs.score_index, + id_index=call.attrs.id_index, + return_indices=call.attrs.return_indices, + invalid_to_bottom=call.attrs.invalid_to_bottom, + ) + + @register_legalize("relax.vision.multibox_transform_loc") def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr: variances = tuple(float(x) for x in call.attrs.variances) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index d9fd005921ed..143ccb845921 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -54,9 +54,11 @@ from .l2_normalize_python import l2_normalize_python from .gather_python import gather_python from .gather_nd_python import gather_nd_python +from .get_valid_counts_python import get_valid_counts_python from .strided_slice_python import strided_slice_python, strided_set_python from .batch_matmul import batch_matmul from .batch_norm import batch_norm +from .nms_python import non_max_suppression_python from .slice_axis_python import slice_axis_python from .sequence_mask_python import sequence_mask from .poolnd_python import poolnd_python diff --git a/python/tvm/topi/testing/get_valid_counts_python.py b/python/tvm/topi/testing/get_valid_counts_python.py new file mode 100644 index 000000000000..2caab6babc9d --- /dev/null +++ b/python/tvm/topi/testing/get_valid_counts_python.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Numpy reference implementation for get_valid_counts."""
+import numpy as np
+
+
+def get_valid_counts_python(data, score_threshold=0, id_index=0, score_index=1):
+    """Numpy reference for get_valid_counts: keep boxes above the threshold, in input order.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        3-D array with shape [batch_size, num_anchors, elem_length].
+
+    score_threshold : float
+        Lower limit of score for valid bounding boxes (strictly greater is valid).
+
+    id_index : int
+        Index of the class categories, -1 to disable the ``class id >= 0`` check.
+
+    score_index : int
+        Index of the scores/confidence of boxes.
+
+    Returns
+    -------
+    valid_count : numpy.ndarray
+        1-D int32 array, shape [batch_size].
+
+    out_tensor : numpy.ndarray
+        Valid boxes compacted to the top, remaining rows filled with -1; same shape as data.
+
+    out_indices : numpy.ndarray
+        Input anchor index of each kept box, -1 padded; shape [batch_size, num_anchors].
+    """
+    batch_size, num_anchors, _ = data.shape
+    valid_count = np.zeros(batch_size, dtype="int32")
+    out_tensor = np.full_like(data, -1.0)
+    out_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
+
+    for i in range(batch_size):
+        cnt = 0
+        for j in range(num_anchors):
+            score = data[i, j, score_index]
+            if id_index < 0:
+                is_valid = score > score_threshold
+            else:
+                is_valid = score > score_threshold and data[i, j, id_index] >= 0
+            if is_valid:
+                out_tensor[i, cnt, :] = data[i, j, :]
+                out_indices[i, cnt] = j
+                cnt += 1
+        valid_count[i] = cnt
+
+    return valid_count, out_tensor, out_indices
diff --git a/python/tvm/topi/testing/nms_python.py b/python/tvm/topi/testing/nms_python.py
new file mode 100644
index 000000000000..7c8c20f5b412
--- /dev/null
+++ b/python/tvm/topi/testing/nms_python.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Numpy reference implementation for classic non_max_suppression."""
+import numpy as np
+
+
+def _iou(box_a, box_b, coord_start):
+    """Compute IoU of two boxes whose 4 coordinates start at ``coord_start`` (any corner order)."""
+    a = box_a[coord_start : coord_start + 4]
+    b = box_b[coord_start : coord_start + 4]
+
+    a_l, a_t, a_r, a_b = min(a[0], a[2]), min(a[1], a[3]), max(a[0], a[2]), max(a[1], a[3])
+    b_l, b_t, b_r, b_b = min(b[0], b[2]), min(b[1], b[3]), max(b[0], b[2]), max(b[1], b[3])
+
+    w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
+    h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
+    area = w * h
+    u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
+    return 0.0 if u <= 0 else area / u
+
+
+def non_max_suppression_python(
+    data,
+    valid_count,
+    indices,
+    max_output_size=-1,
+    iou_threshold=0.5,
+    force_suppress=False,
+    top_k=-1,
+    coord_start=2,
+    score_index=1,
+    id_index=0,
+    return_indices=True,
+    invalid_to_bottom=False,
+):
+    """Numpy reference for classic non_max_suppression (greedy, highest score first).
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        3-D array, shape [batch_size, num_anchors, elem_length].
+
+    valid_count : numpy.ndarray
+        1-D array, shape [batch_size]; only the first valid_count[i] boxes are considered.
+
+    indices : numpy.ndarray
+        2-D array, shape [batch_size, num_anchors]; maps rows of data to returned indices.
+
+    Returns
+    -------
+    If return_indices is True: [box_indices, valid_box_count], both -1 padded / int32
+    Otherwise: modified data tensor (compacted to the top when invalid_to_bottom is set)
+    """
+    batch_size, num_anchors, _ = data.shape
+    out_data = np.full_like(data, -1.0)
+    out_box_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
+    compacted = np.full((batch_size, num_anchors), -1, dtype="int32")
+    valid_box_count = np.zeros((batch_size, 1), dtype="int32")
+
+    for i in range(batch_size):
+        num_valid_in = int(valid_count[i])
+        # top_k keeps the best-scoring boxes, so every valid box is sorted before truncating.
+        nkeep = top_k if 0 < top_k < num_valid_in else num_valid_in
+
+        # Sort by score descending (stable, so ties keep their input order)
+        scores = data[i, :num_valid_in, score_index]
+        sorted_idx = np.argsort(-scores, kind="stable")
+
+        # Copy the nkeep best boxes in score order
+        for j in range(nkeep):
+            src = sorted_idx[j]
+            out_data[i, j, :] = data[i, src, :]
+            out_box_indices[i, j] = src
+
+        # Greedy NMS
+        num_valid = 0
+        for j in range(nkeep):
+            if out_data[i, j, score_index] <= 0:
+                out_data[i, j, :] = -1.0
+                out_box_indices[i, j] = -1
+                continue
+            if 0 < max_output_size <= num_valid:
+                out_data[i, j, :] = -1.0
+                out_box_indices[i, j] = -1
+                continue
+
+            num_valid += 1
+
+            # Suppress overlapping lower-scored boxes
+            for k in range(j + 1, nkeep):
+                if out_data[i, k, score_index] <= 0:
+                    continue
+
+                do_suppress = False
+                if force_suppress:
+                    do_suppress = True
+                elif id_index >= 0:
+                    do_suppress = out_data[i, j, id_index] == out_data[i, k, id_index]
+                else:
+                    do_suppress = True
+
+                if do_suppress:
+                    iou = _iou(out_data[i, j], out_data[i, k], coord_start)
+                    if iou >= iou_threshold:
+                        out_data[i, k, score_index] = -1.0
+                        out_box_indices[i, k] = -1
+
+        if return_indices:
+            # Compact valid indices to top and remap to original
+            cnt = 0
+            for j in range(num_anchors):
+                if out_box_indices[i, j] >= 0:
+                    orig_idx = out_box_indices[i, j]
+                    compacted[i, cnt] = int(indices[i, orig_idx])
+                    cnt += 1
+            valid_box_count[i, 0] = cnt
+
+    if return_indices:
+        return [compacted, valid_box_count]
+
+    if invalid_to_bottom:
+        # Rearrange valid boxes to top
+        result = np.full_like(data, -1.0)
+        for i in range(batch_size):
+            cnt = 0
+            for j in range(num_anchors):
+                if out_data[i, j, score_index] >= 0:
+                    result[i, cnt, :] = out_data[i, j, :]
+                    cnt += 1
+        return result
+
+    return out_data
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index 9bdedc35352c..b69e9c2aa1e3 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -36,37 +36,513 @@
 )
 
 
-def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):  # pylint: disable=unused-argument
+def _get_valid_counts_ir(
+    data, score_threshold, id_index, score_index, valid_count, out_tensor, out_indices
+):
+    """IR for get_valid_counts. Filters boxes by score and compacts valid ones to the top."""
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    box_data_length = data.shape[2]
+
+    with IRBuilder() as ib:
+        data = T.buffer_proxy(data)
+        valid_count = T.buffer_proxy(valid_count)
+        out_tensor = T.buffer_proxy(out_tensor)
+        out_indices = T.buffer_proxy(out_indices)
+
+        # NOTE(review): get_valid_counts passes the raw input buffer (ins[1]) as score_threshold
+        # when the threshold is a te.Tensor; comparing a score against a buffer rather than a
+        # loaded element looks unintended -- confirm and load the scalar before comparing.
+        with T.parallel(0, batch_size) as i:
+            valid_count[i] = T.int32(0)
+
+            with T.serial(0, num_anchors) as j:
+                score = data[i, j, score_index]
+                if id_index < 0:
+                    is_valid = score > score_threshold
+                else:
+                    is_valid = tvm.tirx.all(score > score_threshold, data[i, j, id_index] >= 0)
+
+                with T.If(is_valid):
+                    with T.Then():
+                        cur = valid_count[i]
+                        with T.serial(0, box_data_length) as k:
+                            out_tensor[i, cur, k] = data[i, j, k]
+                        out_indices[i, cur] = j
+                        valid_count[i] = cur + 1
+
+            # Fill remaining slots with -1
+            with T.serial(0, num_anchors) as j:
+                with T.If(j >= valid_count[i]):
+                    with T.Then():
+                        with T.serial(0, box_data_length) as k:
+                            out_tensor[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0))
+                        out_indices[i, j] = T.int32(-1)
+
+        return ib.get()
+
+
+def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     """Get valid count of bounding boxes given a score threshold.
     Also moves valid boxes to the top of input data.
+
     Parameters
     ----------
     data : tvm.te.Tensor
-        Input data.
3-D tensor with shape [batch_size, num_anchors, 6] - or [batch_size, num_anchors, 5]. + Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length]. + score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int - index of the class categories, -1 to disable. + Index of the class categories, -1 to disable. + score_index: optional, int Index of the scores/confidence of boxes. + Returns ------- valid_count : tvm.te.Tensor - 1-D tensor for valid number of boxes. + 1-D tensor for valid number of boxes, shape [batch_size]. + out_tensor : tvm.te.Tensor - Rearranged data tensor. - out_indices: tvm.te.Tensor or numpy NDArray - Related index in input data. + Rearranged data tensor, shape [batch_size, num_anchors, elem_length]. + + out_indices: tvm.te.Tensor + Related index in input data, shape [batch_size, num_anchors]. """ - if isinstance(score_threshold, float | int): + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] + + is_score_threshold_tensor = isinstance(score_threshold, te.Tensor) + if not is_score_threshold_tensor: score_threshold = tvm.tirx.const(score_threshold, dtype=data.dtype) - # id_index_const = tvm.tirx.const(id_index, "int32") # Unused - # score_index_const = tvm.tirx.const(score_index, "int32") # Unused - return ( - te.compute((data.shape[0],), lambda i: data.shape[1], name="valid_count"), - data, - te.compute((data.shape[0], data.shape[1]), lambda i, j: j, name="out_indices"), + + id_index_const = tvm.tirx.const(id_index, "int32") + score_index_const = tvm.tirx.const(score_index, "int32") + + valid_count_buf = tvm.tirx.decl_buffer((batch_size,), "int32", "valid_count") + out_tensor_buf = tvm.tirx.decl_buffer( + (batch_size, num_anchors, box_data_length), data.dtype, "out_tensor" + ) + out_indices_buf = tvm.tirx.decl_buffer( + (batch_size, num_anchors), "int32", "out_indices" + ) + + if is_score_threshold_tensor: + score_thresh_buf = 
tvm.tirx.decl_buffer( + score_threshold.shape, score_threshold.dtype, "score_threshold" + ) + valid_count, out_tensor, out_indices = te.extern( + [(batch_size,), (batch_size, num_anchors, box_data_length), (batch_size, num_anchors)], + [data, score_threshold], + lambda ins, outs: _get_valid_counts_ir( + ins[0], ins[1], id_index_const, score_index_const, + outs[0], outs[1], outs[2], + ), + dtype=["int32", data.dtype, "int32"], + out_buffers=[valid_count_buf, out_tensor_buf, out_indices_buf], + in_buffers=[ + tvm.tirx.decl_buffer(data.shape, data.dtype, "data"), + score_thresh_buf, + ], + name="get_valid_counts", + tag="get_valid_counts", + ) + else: + # score_threshold is a TIR constant, not a tensor + def _ir_with_const_threshold(ins, outs): + return _get_valid_counts_ir( + ins[0], score_threshold, id_index_const, score_index_const, + outs[0], outs[1], outs[2], + ) + + valid_count, out_tensor, out_indices = te.extern( + [(batch_size,), (batch_size, num_anchors, box_data_length), (batch_size, num_anchors)], + [data], + _ir_with_const_threshold, + dtype=["int32", data.dtype, "int32"], + out_buffers=[valid_count_buf, out_tensor_buf, out_indices_buf], + in_buffers=[tvm.tirx.decl_buffer(data.shape, data.dtype, "data")], + name="get_valid_counts", + tag="get_valid_counts", + ) + + return valid_count, out_tensor, out_indices + + +def _classic_nms_ir( + data, + sorted_index, + valid_count, + indices, + batch_size, + num_anchors, + box_data_length, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + score_index, + id_index, + return_indices, + out_data, + out_box_indices, + out_valid_box_count, +): + """IR for classic single-class non-maximum suppression.""" + with IRBuilder() as ib: + data = T.buffer_proxy(data) + sorted_index = T.buffer_proxy(sorted_index) + valid_count = T.buffer_proxy(valid_count) + indices = T.buffer_proxy(indices) + out_data = T.buffer_proxy(out_data) + out_box_indices = T.buffer_proxy(out_box_indices) + if 
out_valid_box_count is not None: + out_valid_box_count = T.buffer_proxy(out_valid_box_count) + + with T.parallel(0, batch_size) as i: + # Step 1: Reorder data by sorted score + nkeep_buf = T.alloc_buffer((1,), "int32", scope="local") + nkeep_local = T.buffer_proxy(nkeep_buf) + nkeep_local[0] = valid_count[i] + with T.If(tvm.tirx.all(top_k > 0, top_k < nkeep_local[0])): + with T.Then(): + nkeep_local[0] = top_k + + # Copy sorted boxes to output + with T.serial(0, num_anchors) as j: + with T.If(j < nkeep_local[0]): + with T.Then(): + src_idx = sorted_index[i, j] + with T.serial(0, box_data_length) as k: + out_data[i, j, k] = data[i, src_idx, k] + out_box_indices[i, j] = sorted_index[i, j] + with T.Else(): + with T.serial(0, box_data_length) as k: + out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) + out_box_indices[i, j] = T.int32(-1) + + # Step 2: Apply NMS - greedy suppression + num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local") + num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf) + num_valid_boxes[0] = T.int32(0) + + with T.serial(0, nkeep_local[0]) as j: + # Check if box j is still valid (score > 0) and within max_output_size + with T.If( + tvm.tirx.all( + out_data[i, j, score_index] > tvm.tirx.Cast(data.dtype, T.float32(0.0)), + tvm.tirx.Select( + max_output_size > 0, + num_valid_boxes[0] < max_output_size, + tvm.tirx.const(True), + ), + ) + ): + with T.Then(): + num_valid_boxes[0] = num_valid_boxes[0] + 1 + + # Suppress overlapping boxes + with T.serial(0, nkeep_local[0]) as k: + with T.If( + tvm.tirx.all( + k > j, + out_data[i, k, score_index] + > tvm.tirx.Cast(data.dtype, T.float32(0.0)), + ) + ): + with T.Then(): + # Check class ID match (or force_suppress) + do_suppress = tvm.tirx.const(False) + if force_suppress: + do_suppress = tvm.tirx.const(True) + elif id_index >= 0: + do_suppress = ( + out_data[i, j, id_index] == out_data[i, k, id_index] + ) + else: + do_suppress = tvm.tirx.const(True) + + with T.If(do_suppress): + with 
T.Then(): + # Calculate IoU + a_l = tvm.te.min( + out_data[i, j, coord_start], + out_data[i, j, coord_start + 2], + ) + a_t = tvm.te.min( + out_data[i, j, coord_start + 1], + out_data[i, j, coord_start + 3], + ) + a_r = tvm.te.max( + out_data[i, j, coord_start], + out_data[i, j, coord_start + 2], + ) + a_b = tvm.te.max( + out_data[i, j, coord_start + 1], + out_data[i, j, coord_start + 3], + ) + + b_l = tvm.te.min( + out_data[i, k, coord_start], + out_data[i, k, coord_start + 2], + ) + b_t = tvm.te.min( + out_data[i, k, coord_start + 1], + out_data[i, k, coord_start + 3], + ) + b_r = tvm.te.max( + out_data[i, k, coord_start], + out_data[i, k, coord_start + 2], + ) + b_b = tvm.te.max( + out_data[i, k, coord_start + 1], + out_data[i, k, coord_start + 3], + ) + + w = tvm.te.max( + tvm.tirx.Cast(data.dtype, T.float32(0.0)), + tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l), + ) + h = tvm.te.max( + tvm.tirx.Cast(data.dtype, T.float32(0.0)), + tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t), + ) + area = h * w + u = ( + (a_r - a_l) * (a_b - a_t) + + (b_r - b_l) * (b_b - b_t) + - area + ) + iou = tvm.tirx.Select( + u <= tvm.tirx.Cast(data.dtype, T.float32(0.0)), + tvm.tirx.Cast(data.dtype, T.float32(0.0)), + area / u, + ) + + with T.If(iou >= iou_threshold): + with T.Then(): + out_data[i, k, score_index] = tvm.tirx.Cast( + data.dtype, T.float32(-1.0) + ) + out_box_indices[i, k] = T.int32(-1) + + with T.Else(): + # Box suppressed or beyond max_output_size + with T.serial(0, box_data_length) as k: + out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) + out_box_indices[i, j] = T.int32(-1) + + # Step 3: If return_indices, remap to original indices + if return_indices: + if out_valid_box_count is not None: + # Count valid boxes and remap indices + valid_idx_buf = T.alloc_buffer((1,), "int32", scope="local") + valid_idx = T.buffer_proxy(valid_idx_buf) + valid_idx[0] = T.int32(0) + + with T.serial(0, num_anchors) as j: + with T.If(out_box_indices[i, j] >= 0): + with T.Then(): 
+ orig_idx = out_box_indices[i, j] + out_box_indices[i, valid_idx[0]] = indices[i, orig_idx] + valid_idx[0] = valid_idx[0] + 1 + + out_valid_box_count[i, 0] = valid_idx[0] + + # Fill remaining with -1 + with T.serial(0, num_anchors) as j: + with T.If(j >= valid_idx[0]): + with T.Then(): + out_box_indices[i, j] = T.int32(-1) + + return ib.get() + + +def non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, +): + """Non-maximum suppression operator for object detection. + + Parameters + ---------- + data : tvm.te.Tensor + 3-D tensor with shape [batch_size, num_anchors, elem_length]. + + valid_count : tvm.te.Tensor + 1-D tensor for valid number of boxes, shape [batch_size]. + + indices : tvm.te.Tensor + 2-D tensor with shape [batch_size, num_anchors]. + + max_output_size : optional, int + Max number of output valid boxes for each instance. + Return all valid boxes if the value is less than 0. + + iou_threshold : optional, float + Non-maximum suppression IoU threshold. + + force_suppress : optional, boolean + Whether to suppress all detections regardless of class_id. When + ``id_index`` is ``-1``, all valid boxes are treated as belonging to the + same class, so this flag has the same effect as ``True``. + + top_k : optional, int + Keep maximum top k detections before nms, -1 for no limit. + + coord_start : required, int + Start index of the consecutive 4 coordinates. + + score_index: optional, int + Index of the scores/confidence of boxes. + + id_index : optional, int + Index of the class categories, -1 to disable. + + return_indices : optional, boolean + Whether to return box indices in input data. + + invalid_to_bottom : optional, boolean + Whether to move all valid bounding boxes to the top. 
+ + Returns + ------- + out : tvm.te.Tensor or tuple of tvm.te.Tensor + If return_indices is True, returns a tuple of (box_indices, valid_box_count). + Otherwise returns the modified data tensor. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] + + if isinstance(max_output_size, int): + max_output_size = tvm.tirx.const(max_output_size, dtype="int32") + if isinstance(iou_threshold, (float, int)): + iou_threshold = tvm.tirx.const(iou_threshold, dtype=data.dtype) + + # Sort by score + score_shape = (batch_size, num_anchors) + score_tensor = te.compute( + score_shape, lambda i, j: data[i, j, score_index], name="score_tensor" + ) + sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False) + + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data") + sort_buf = tvm.tirx.decl_buffer(sort_tensor.shape, sort_tensor.dtype, "sorted_index") + valid_count_buf = tvm.tirx.decl_buffer(valid_count.shape, valid_count.dtype, "valid_count") + indices_buf = tvm.tirx.decl_buffer(indices.shape, indices.dtype, "indices") + + out_data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "out_data") + out_box_indices_buf = tvm.tirx.decl_buffer( + (batch_size, num_anchors), "int32", "out_box_indices" + ) + + if return_indices: + out_valid_box_count_buf = tvm.tirx.decl_buffer( + (batch_size, 1), "int32", "out_valid_box_count" + ) + + out_data, out_box_indices, out_valid_box_count = te.extern( + [data.shape, (batch_size, num_anchors), (batch_size, 1)], + [data, sort_tensor, valid_count, indices], + lambda ins, outs: _classic_nms_ir( + ins[0], ins[1], ins[2], ins[3], + batch_size, num_anchors, box_data_length, + max_output_size, iou_threshold, + force_suppress, top_k, + coord_start, score_index, id_index, + return_indices, + outs[0], outs[1], outs[2], + ), + dtype=[data.dtype, "int32", "int32"], + out_buffers=[out_data_buf, out_box_indices_buf, out_valid_box_count_buf], + in_buffers=[data_buf, sort_buf, 
valid_count_buf, indices_buf], + name="non_max_suppression", + tag="non_max_suppression", + ) + return [out_box_indices, out_valid_box_count] + + out_data, out_box_indices = te.extern( + [data.shape, (batch_size, num_anchors)], + [data, sort_tensor, valid_count, indices], + lambda ins, outs: _classic_nms_ir( + ins[0], ins[1], ins[2], ins[3], + batch_size, num_anchors, box_data_length, + max_output_size, iou_threshold, + force_suppress, top_k, + coord_start, score_index, id_index, + return_indices, + outs[0], outs[1], None, + ), + dtype=[data.dtype, "int32"], + out_buffers=[out_data_buf, out_box_indices_buf], + in_buffers=[data_buf, sort_buf, valid_count_buf, indices_buf], + name="non_max_suppression", + tag="non_max_suppression", + ) + + if invalid_to_bottom: + # Rearrange to move valid boxes to top + return _rearrange_out(out_data, batch_size, num_anchors, box_data_length, score_index) + + return out_data + + +def _rearrange_out(data, batch_size, num_anchors, box_data_length, score_index): + """Move valid boxes (score >= 0) to the top of output.""" + out_buf = tvm.tirx.decl_buffer( + (batch_size, num_anchors, box_data_length), data.dtype, "rearranged" + ) + + def _rearrange_ir(ins, outs): + with IRBuilder() as ib: + data = T.buffer_proxy(ins[0]) + out = T.buffer_proxy(outs[0]) + + with T.parallel(0, batch_size) as i: + valid_idx_buf = T.alloc_buffer((1,), "int32", scope="local") + valid_idx = T.buffer_proxy(valid_idx_buf) + valid_idx[0] = T.int32(0) + + with T.serial(0, num_anchors) as j: + with T.If( + data[i, j, score_index] >= tvm.tirx.Cast(data.dtype, T.float32(0.0)) + ): + with T.Then(): + with T.serial(0, box_data_length) as k: + out[i, valid_idx[0], k] = data[i, j, k] + valid_idx[0] = valid_idx[0] + 1 + + with T.serial(0, num_anchors) as j: + with T.If(j >= valid_idx[0]): + with T.Then(): + with T.serial(0, box_data_length) as k: + out[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) + + return ib.get() + + return te.extern( + [(batch_size, 
num_anchors, box_data_length)], + [data], + _rearrange_ir, + dtype=[data.dtype], + out_buffers=[out_buf], + name="rearrange_out", + tag="rearrange_out", ) diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index ae1716897069..a4b4c78363a3 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -303,8 +303,8 @@ def _all_class_nms_ir( if selected_scores is not None: selected_scores = T.buffer_proxy(selected_scores) - if isinstance(iou_threshold, float): - iou_threshold = tvm.tirx.FloatImm("float32", iou_threshold) + if isinstance(iou_threshold, (float, int)): + iou_threshold = tvm.tirx.FloatImm("float32", float(iou_threshold)) elif isinstance(iou_threshold, te.Tensor): if len(iou_threshold.shape) == 0: iou_threshold = iou_threshold() diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index 294cd40c4515..97508d721189 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -18,6 +18,7 @@ */ #include "nms.h" +#include #include #include #include @@ -33,7 +34,11 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK() { AllClassNonMaximumSuppressionAttrs::RegisterReflection(); } +TVM_FFI_STATIC_INIT_BLOCK() { + AllClassNonMaximumSuppressionAttrs::RegisterReflection(); + GetValidCountsAttrs::RegisterReflection(); + NonMaximumSuppressionAttrs::RegisterReflection(); +} /* relax.vision.all_class_non_max_suppression */ @@ -110,5 +115,242 @@ TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression") .set_attr("FInferStructInfo", InferStructInfoAllClassNMS) .set_attr("FPurity", Bool(true)); +/* relax.vision.get_valid_counts */ + +Expr get_valid_counts(Expr data, double score_threshold, int id_index, int score_index) { + auto attrs = tvm::ffi::make_object(); + attrs->score_threshold = score_threshold; + attrs->id_index = id_index; + attrs->score_index = score_index; + + static const Op& op = Op::Get("relax.vision.get_valid_counts"); + return Call(op, 
{std::move(data)}, Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vision.get_valid_counts", get_valid_counts); +} + +StructInfo InferStructInfoGetValidCounts(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "get_valid_counts expects 1 argument, got " << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "get_valid_counts expects input data to be a Tensor."); + } + if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "get_valid_counts expects 3-D input, got ndim " << data_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid get_valid_counts attrs"; + auto vdev = data_sinfo->vdevice; + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + tvm::ffi::Array fields = { + TensorStructInfo(DataType::Int(32), /*ndim=*/1, vdev), + TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev), + TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)}; + return TupleStructInfo(fields); + } + + auto batch = data_shape->values[0]; + auto num_anchors = data_shape->values[1]; + auto elem_length = data_shape->values[2]; + const auto* elem_length_imm = elem_length.as(); + if (elem_length_imm != nullptr) { + if (attrs->score_index < 0 || attrs->score_index >= elem_length_imm->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "get_valid_counts expects score_index to be in range [0, " + << elem_length_imm->value << "), but got " << attrs->score_index); + } + if (attrs->id_index < -1 || attrs->id_index >= elem_length_imm->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "get_valid_counts expects id_index to be in range [-1, " + << elem_length_imm->value << "), but got " << 
attrs->id_index); + } + } + + tvm::ffi::Array fields = { + TensorStructInfo(ShapeExpr({batch}), DataType::Int(32), vdev), + TensorStructInfo(ShapeExpr({batch, num_anchors, elem_length}), data_sinfo->dtype, vdev), + TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev)}; + return TupleStructInfo(fields); +} + +TVM_REGISTER_OP("relax.vision.get_valid_counts") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", + "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") + .set_attr("FInferStructInfo", InferStructInfoGetValidCounts) + .set_attr("FPurity", Bool(true)); + +/* relax.vision.non_max_suppression */ + +Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int max_output_size, + double iou_threshold, bool force_suppress, int top_k, int coord_start, + int score_index, int id_index, bool return_indices, + bool invalid_to_bottom) { + auto attrs = tvm::ffi::make_object(); + attrs->max_output_size = max_output_size; + attrs->iou_threshold = iou_threshold; + attrs->force_suppress = force_suppress; + attrs->top_k = top_k; + attrs->coord_start = coord_start; + attrs->score_index = score_index; + attrs->id_index = id_index; + attrs->return_indices = return_indices; + attrs->invalid_to_bottom = invalid_to_bottom; + + static const Op& op = Op::Get("relax.vision.non_max_suppression"); + return Call(op, {std::move(data), std::move(valid_count), std::move(indices)}, Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vision.non_max_suppression", non_max_suppression); +} + +StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects 3 arguments, got " << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* valid_count_sinfo = GetStructInfoAs(call->args[1]); + const 
auto* indices_sinfo = GetStructInfoAs(call->args[2]); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects input data to be a Tensor."); + } + if (valid_count_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects valid_count to be a Tensor."); + } + if (indices_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects indices to be a Tensor."); + } + if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects 3-D input, got ndim " << data_sinfo->ndim); + } + if (valid_count_sinfo->ndim != -1 && valid_count_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects valid_count to be 1-D, got ndim " + << valid_count_sinfo->ndim); + } + if (indices_sinfo->ndim != -1 && indices_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects indices to be 2-D, got ndim " + << indices_sinfo->ndim); + } + if (!valid_count_sinfo->IsUnknownDtype() && valid_count_sinfo->dtype != DataType::Int(32)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects valid_count to have dtype int32, got " + << valid_count_sinfo->dtype); + } + if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(32)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects indices to have dtype int32, got " + << indices_sinfo->dtype); + } + + const auto* data_shape = data_sinfo->shape.as(); + const auto* valid_count_shape = valid_count_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + if (data_shape != nullptr) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr batch = data_shape->values[0]; + PrimExpr num_anchors = data_shape->values[1]; + if (valid_count_shape != nullptr && + 
!analyzer->CanProveEqual(valid_count_shape->values[0], batch)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects valid_count to have shape [batch_size]. " + "However, the given data tensor has batch size `" + << batch << "` and the given valid_count tensor has shape " + << valid_count_sinfo->shape); + } + if (indices_shape != nullptr) { + if (!analyzer->CanProveEqual(indices_shape->values[0], batch) || + !analyzer->CanProveEqual(indices_shape->values[1], num_anchors)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "non_max_suppression expects indices to have shape [batch_size, num_anchors]. " + "However, the given data tensor has shape " + << data_sinfo->shape << " and the given indices tensor has shape " + << indices_sinfo->shape); + } + } + } + + const auto* attrs = call->attrs.as(); + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid non_max_suppression attrs"; + auto vdev = data_sinfo->vdevice; + if (data_shape != nullptr) { + const auto* elem_length_imm = data_shape->values[2].as(); + if (elem_length_imm != nullptr) { + int64_t elem_length = elem_length_imm->value; + if (attrs->score_index < 0 || attrs->score_index >= elem_length) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects score_index to be in range [0, " + << elem_length << "), but got " << attrs->score_index); + } + if (attrs->coord_start < 0 || attrs->coord_start + 3 >= elem_length) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects coord_start to reference four " + "consecutive box coordinates within elem_length " + << elem_length << ", but got " << attrs->coord_start); + } + if (attrs->id_index < -1 || attrs->id_index >= elem_length) { + ctx->ReportFatal(Diagnostic::Error(call) + << "non_max_suppression expects id_index to be in range [-1, " + << elem_length << "), but got " << attrs->id_index); + } + } + } + + if (attrs->return_indices) { + // Returns (box_indices[batch, num_anchors], valid_box_count[batch, 1]) 
+ if (data_shape == nullptr) { + tvm::ffi::Array fields = { + TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev), + TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)}; + return TupleStructInfo(fields); + } + auto batch = data_shape->values[0]; + auto num_anchors = data_shape->values[1]; + tvm::ffi::Array fields = { + TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), + TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}), DataType::Int(32), + vdev)}; + return TupleStructInfo(fields); + } + + // Returns modified data tensor with the same shape as input. + if (const auto* data_shape = data_sinfo->shape.as()) { + return TensorStructInfo(ffi::GetRef(data_shape), data_sinfo->dtype, vdev); + } + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev); +} + +TVM_REGISTER_OP("relax.vision.non_max_suppression") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", + "Input data, 3-D tensor [batch_size, num_anchors, elem_length].") + .add_argument("valid_count", "Tensor", "1-D tensor for valid number of boxes.") + .add_argument("indices", "Tensor", "2-D tensor with shape [batch_size, num_anchors].") + .set_attr("FInferStructInfo", InferStructInfoNMS) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h index c86bf98c94d5..3fbd2609e289 100644 --- a/src/relax/op/vision/nms.h +++ b/src/relax/op/vision/nms.h @@ -38,6 +38,15 @@ Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxe Expr iou_threshold, Expr score_threshold, ffi::String output_format); +/*! \brief Get valid count of bounding boxes given a score threshold. */ +Expr get_valid_counts(Expr data, double score_threshold, int id_index, int score_index); + +/*! \brief Non-maximum suppression for object detection. 
*/ +Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int max_output_size, + double iou_threshold, bool force_suppress, int top_k, int coord_start, + int score_index, int id_index, bool return_indices, + bool invalid_to_bottom); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index cded9f5f29e5..6d04a796ca9b 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -20,6 +20,7 @@ import tvm import tvm.testing +import tvm.topi.testing from tvm import TVMError, relax, tirx from tvm.ir import Op from tvm.relax.transform import LegalizeOps @@ -31,6 +32,23 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) +def _assert_relax_op_legalized(mod: tvm.IRModule, op_name: str) -> None: + seen_call_tir = False + seen_original_op = False + + def _visit(expr): + nonlocal seen_call_tir, seen_original_op + if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op): + if expr.op.name == "relax.call_tir": + seen_call_tir = True + if expr.op.name == op_name: + seen_original_op = True + + relax.analysis.post_order_visit(mod["main"].body, _visit) + assert seen_call_tir + assert not seen_original_op + + def test_roi_align_op_correctness(): x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) rois = relax.Var("rois", R.Tensor((4, 5), "float32")) @@ -198,6 +216,840 @@ def main( ) +def test_get_valid_counts_op_correctness(): + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + assert relax.op.vision.get_valid_counts(data, 0.5).op == Op.get("relax.vision.get_valid_counts") + + +def test_get_valid_counts_infer_struct_info(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + _check_inference( + bb, + relax.op.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1), + 
relax.TupleStructInfo( + [ + relax.TensorStructInfo((2,), "int32"), + relax.TensorStructInfo((2, 10, 6), "float32"), + relax.TensorStructInfo((2, 10), "int32"), + ] + ), + ) + + +def test_get_valid_counts_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") + k = tirx.Var("k", "int64") + data = relax.Var("data", R.Tensor((n, m, k), "float32")) + _check_inference( + bb, + relax.op.vision.get_valid_counts(data, score_threshold=0.0), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((n,), "int32"), + relax.TensorStructInfo((n, m, k), "float32"), + relax.TensorStructInfo((n, m), "int32"), + ] + ), + ) + + +def test_get_valid_counts_wrong_ndim(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((10, 6), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.get_valid_counts(data)) + + +def test_get_valid_counts_invalid_indices(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.get_valid_counts(data, score_index=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.get_valid_counts(data, id_index=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.get_valid_counts(data, id_index=-2)) + + +def test_nms_op_correctness(): + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 10), "int32")) + assert relax.op.vision.non_max_suppression( + data, valid_count, indices + ).op == Op.get("relax.vision.non_max_suppression") + + +def test_nms_infer_struct_info_return_indices(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 10), "int32")) + _check_inference( + bb, + 
relax.op.vision.non_max_suppression( + data, valid_count, indices, return_indices=True + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 10), "int32"), + relax.TensorStructInfo((2, 1), "int32"), + ] + ), + ) + + +def test_nms_infer_struct_info_return_data(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 10), "int32")) + _check_inference( + bb, + relax.op.vision.non_max_suppression( + data, valid_count, indices, return_indices=False + ), + relax.TensorStructInfo((2, 10, 6), "float32"), + ) + + +def test_nms_infer_struct_info_return_data_shape_var(): + bb = relax.BlockBuilder() + batch_size = tirx.Var("batch_size", "int64") + num_anchors = tirx.Var("num_anchors", "int64") + elem_length = tirx.Var("elem_length", "int64") + data = relax.Var("data", R.Tensor((batch_size, num_anchors, elem_length), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((batch_size,), "int32")) + indices = relax.Var("indices", R.Tensor((batch_size, num_anchors), "int32")) + _check_inference( + bb, + relax.op.vision.non_max_suppression( + data, valid_count, indices, return_indices=False + ), + relax.TensorStructInfo((batch_size, num_anchors, elem_length), "float32"), + ) + + +def test_nms_wrong_ndim(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 10), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices)) + + +def test_nms_wrong_valid_count_ndim(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2, 1), "int32")) + indices = relax.Var("indices", R.Tensor((2, 10), "int32")) + with pytest.raises(TVMError): + 
bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices)) + + +def test_nms_wrong_indices_ndim(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((20,), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices)) + + +def test_nms_wrong_aux_input_dtype(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count_i64 = relax.Var("valid_count_i64", R.Tensor((2,), "int64")) + valid_count_i32 = relax.Var("valid_count_i32", R.Tensor((2,), "int32")) + indices_i64 = relax.Var("indices_i64", R.Tensor((2, 10), "int64")) + indices_i32 = relax.Var("indices_i32", R.Tensor((2, 10), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count_i64, indices_i32)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count_i32, indices_i64)) + + +def test_nms_wrong_aux_input_shape(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count_bad_batch = relax.Var("valid_count_bad_batch", R.Tensor((3,), "int32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices_bad_batch = relax.Var("indices_bad_batch", R.Tensor((3, 10), "int32")) + indices_bad_anchors = relax.Var("indices_bad_anchors", R.Tensor((2, 9), "int32")) + with pytest.raises(TVMError): + bb.normalize( + relax.op.vision.non_max_suppression( + data, valid_count_bad_batch, indices_bad_anchors + ) + ) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices_bad_batch)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices_bad_anchors)) + + +def test_nms_invalid_indices(): + bb = 
relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 10), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, score_index=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, id_index=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, id_index=-2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, coord_start=3)) + + +def test_get_valid_counts_legalize(): + @tvm.script.ir_module + class GVC: + @R.function + def main( + data: R.Tensor((1, 5, 6), "float32"), + ) -> R.Tuple( + R.Tensor((1,), "int32"), + R.Tensor((1, 5, 6), "float32"), + R.Tensor((1, 5), "int32"), + ): + gv = R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1) + return gv + + mod = LegalizeOps()(GVC) + _assert_relax_op_legalized(mod, "relax.vision.get_valid_counts") + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1,), "int32"), + relax.TensorStructInfo((1, 5, 6), "float32"), + relax.TensorStructInfo((1, 5), "int32"), + ] + ), + ) + + +def test_nms_legalize(): + @tvm.script.ir_module + class NMS: + @R.function + def main( + data: R.Tensor((1, 5, 6), "float32"), + valid_count: R.Tensor((1,), "int32"), + indices: R.Tensor((1, 5), "int32"), + ) -> R.Tuple(R.Tensor((1, 5), "int32"), R.Tensor((1, 1), "int32")): + gv = R.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + return gv + + mod = LegalizeOps()(NMS) + 
_assert_relax_op_legalized(mod, "relax.vision.non_max_suppression") + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1, 5), "int32"), + relax.TensorStructInfo((1, 1), "int32"), + ] + ), + ) + + +def test_nms_legalize_return_data(): + @tvm.script.ir_module + class NMS: + @R.function + def main( + data: R.Tensor((1, 5, 6), "float32"), + valid_count: R.Tensor((1,), "int32"), + indices: R.Tensor((1, 5), "int32"), + ) -> R.Tensor((1, 5, 6), "float32"): + gv = R.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=False, + invalid_to_bottom=True, + ) + return gv + + mod = LegalizeOps()(NMS) + _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression") + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TensorStructInfo((1, 5, 6), "float32"), + ) + + +@tvm.testing.requires_llvm +def test_get_valid_counts_e2e(): + """Run get_valid_counts through legalization and compare with the numpy reference.""" + + @tvm.script.ir_module + class GVCModule: + @R.function + def main( + data: R.Tensor((2, 5, 6), "float32"), + ) -> R.Tuple( + R.Tensor((2,), "int32"), + R.Tensor((2, 5, 6), "float32"), + R.Tensor((2, 5), "int32"), + ): + return R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1) + + data_np = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [1.0, 0.30, 0.0, 0.0, 1.0, 1.0], + [-1.0, 0.90, 0.0, 0.0, 1.0, 1.0], + [2.0, 0.75, 2.0, 2.0, 3.0, 3.0], + [1.0, 0.10, 4.0, 4.0, 5.0, 5.0], + ], + [ + [0.0, 0.55, 0.0, 0.0, 1.0, 1.0], + [1.0, 0.80, 1.0, 1.0, 2.0, 2.0], + [2.0, 0.40, 2.0, 2.0, 3.0, 3.0], + [3.0, 0.60, 3.0, 3.0, 4.0, 4.0], + [-1.0, 0.95, 5.0, 5.0, 6.0, 6.0], + ], + ], + dtype="float32", + ) + ref_valid_count, ref_out_data, ref_out_indices = tvm.topi.testing.get_valid_counts_python( + data_np, 
score_threshold=0.5, id_index=0, score_index=1 + ) + + mod = LegalizeOps()(GVCModule) + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + result = vm["main"](tvm.runtime.tensor(data_np, tvm.cpu())) + + tvm.testing.assert_allclose(result[0].numpy(), ref_valid_count) + tvm.testing.assert_allclose(result[1].numpy(), ref_out_data) + tvm.testing.assert_allclose(result[2].numpy(), ref_out_indices) + + +def _prepare_nms_inputs(raw_data: np.ndarray): + """Prepare classic NMS inputs with the numpy get_valid_counts reference.""" + + return tvm.topi.testing.get_valid_counts_python( + raw_data, score_threshold=0.5, id_index=0, score_index=1 + ) + + +def _run_nms_e2e( + data_np: np.ndarray, + valid_count_np: np.ndarray, + indices_np: np.ndarray, + *, + max_output_size: int = -1, + iou_threshold: float = 0.5, + force_suppress: bool = False, + top_k: int = -1, + coord_start: int = 2, + score_index: int = 1, + id_index: int = 0, + return_indices: bool = True, + invalid_to_bottom: bool = False, +): + """Run classic NMS through legalization and VM execution.""" + + data_shape = tuple(int(dim) for dim in data_np.shape) + valid_count_shape = tuple(int(dim) for dim in valid_count_np.shape) + indices_shape = tuple(int(dim) for dim in indices_np.shape) + data = relax.Var("data", relax.TensorStructInfo(data_shape, "float32")) + valid_count = relax.Var("valid_count", relax.TensorStructInfo(valid_count_shape, "int32")) + indices = relax.Var("indices", relax.TensorStructInfo(indices_shape, "int32")) + + bb = relax.BlockBuilder() + with bb.function("main", (data, valid_count, indices)): + result = bb.emit( + relax.op.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + force_suppress=force_suppress, + top_k=top_k, + coord_start=coord_start, + score_index=score_index, + id_index=id_index, + return_indices=return_indices, + invalid_to_bottom=invalid_to_bottom, + ) + ) + 
bb.emit_func_output(result) + + mod = LegalizeOps()(bb.get()) + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + return vm["main"]( + tvm.runtime.tensor(data_np, tvm.cpu()), + tvm.runtime.tensor(valid_count_np, tvm.cpu()), + tvm.runtime.tensor(indices_np, tvm.cpu()), + ) + + +@tvm.testing.requires_llvm +def test_nms_e2e_return_indices(): + """Run classic NMS through legalization and compare with the numpy reference.""" + + raw_data = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.90, 0.05, 0.05, 1.05, 1.05], + [1.0, 0.85, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.60, 2.0, 2.0, 3.0, 3.0], + [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + return_indices=True, + invalid_to_bottom=False, + ) + + tvm.testing.assert_allclose(result[0].numpy(), ref_indices) + tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) + + +@tvm.testing.requires_llvm +def test_nms_e2e_return_indices_with_invalid_to_bottom(): + """Validate that invalid_to_bottom is a no-op when returning indices.""" + + raw_data = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.90, 0.05, 0.05, 1.05, 1.05], + [1.0, 0.85, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.60, 2.0, 2.0, 3.0, 3.0], + [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + 
valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + return_indices=True, + invalid_to_bottom=True, + ) + + tvm.testing.assert_allclose(result[0].numpy(), ref_indices) + tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) + + +@tvm.testing.requires_llvm +def test_nms_e2e_top_k(): + """Validate that classic NMS honors top_k before suppression.""" + + raw_data = np.array( + [ + [ + [-1.0, 0.99, 9.0, 9.0, 10.0, 10.0], + [0.0, 0.97, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.96, 2.0, 2.0, 3.0, 3.0], + [0.0, 0.95, 4.0, 4.0, 5.0, 5.0], + [1.0, 0.94, 6.0, 6.0, 7.0, 7.0], + [0.0, 0.20, 8.0, 8.0, 9.0, 9.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=2, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + top_k=2, + return_indices=True, + invalid_to_bottom=False, + ) + + tvm.testing.assert_allclose(result[0].numpy(), ref_indices) + tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) + np.testing.assert_array_equal(ref_indices, np.array([[1, 2, -1, -1, -1, -1]], dtype="int32")) + np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32")) + + +@tvm.testing.requires_llvm +def test_nms_e2e_force_suppress(): + """Validate that force_suppress ignores class ids when suppressing overlaps.""" + + raw_data = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [1.0, 
0.90, 0.05, 0.05, 1.05, 1.05], + [1.0, 0.80, 2.0, 2.0, 3.0, 3.0], + [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=True, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + force_suppress=True, + return_indices=True, + invalid_to_bottom=False, + ) + + tvm.testing.assert_allclose(result[0].numpy(), ref_indices) + tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) + np.testing.assert_array_equal(ref_indices, np.array([[0, 2, -1, -1]], dtype="int32")) + np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32")) + + +@tvm.testing.requires_llvm +def test_nms_e2e_max_output_size(): + """Validate that max_output_size truncates the kept boxes after score sorting.""" + + raw_data = np.array( + [ + [ + [0.0, 0.97, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.95, 2.0, 2.0, 3.0, 3.0], + [0.0, 0.93, 4.0, 4.0, 5.0, 5.0], + [0.0, 0.91, 6.0, 6.0, 7.0, 7.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=2, + iou_threshold=1, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=2, + iou_threshold=1, + return_indices=True, + invalid_to_bottom=False, + ) + + 
tvm.testing.assert_allclose(result[0].numpy(), ref_indices) + tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) + np.testing.assert_array_equal(ref_indices, np.array([[0, 1, -1, -1]], dtype="int32")) + np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32")) + + +@tvm.testing.requires_llvm +def test_nms_e2e_multi_batch(): + """Validate that classic NMS processes each batch independently.""" + + raw_data = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.90, 0.05, 0.05, 1.05, 1.05], + [1.0, 0.80, 2.0, 2.0, 3.0, 3.0], + [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0], + ], + [ + [1.0, 0.96, 0.0, 0.0, 1.0, 1.0], + [2.0, 0.94, 0.04, 0.04, 1.04, 1.04], + [2.0, 0.88, 3.0, 3.0, 4.0, 4.0], + [2.0, 0.30, 6.0, 6.0, 7.0, 7.0], + ], + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + return_indices=True, + invalid_to_bottom=False, + ) + + tvm.testing.assert_allclose(result[0].numpy(), ref_indices) + tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) + np.testing.assert_array_equal( + ref_indices, + np.array([[0, 2, -1, -1], [0, 1, 2, -1]], dtype="int32"), + ) + np.testing.assert_array_equal(ref_valid_box_count, np.array([[2], [3]], dtype="int32")) + + +@tvm.testing.requires_llvm +def test_nms_e2e_invalid_to_bottom(): + """Validate that invalid_to_bottom compacts only boxes that remain valid after NMS.""" + + raw_data = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.90, 0.05, 0.05, 1.05, 1.05], + [1.0, 0.85, 0.0, 0.0, 1.0, 1.0], + [0.0, 
0.60, 2.0, 2.0, 3.0, 3.0], + [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_out_data = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=False, + invalid_to_bottom=True, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + return_indices=False, + invalid_to_bottom=True, + ) + expected_out_data = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [1.0, 0.85, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.60, 2.0, 2.0, 3.0, 3.0], + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + ] + ], + dtype="float32", + ) + + tvm.testing.assert_allclose(result.numpy(), ref_out_data) + tvm.testing.assert_allclose(result.numpy(), expected_out_data) + + +@tvm.testing.requires_llvm +def test_nms_e2e_return_data_without_compaction(): + """Validate the return_indices=False path when invalid boxes stay in-place.""" + + raw_data = np.array( + [ + [ + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.90, 0.05, 0.05, 1.05, 1.05], + [1.0, 0.85, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.60, 2.0, 2.0, 3.0, 3.0], + [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_out_data = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=False, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + return_indices=False, + invalid_to_bottom=False, + ) + expected_out_data = np.array( + [ + [ + 
[0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + [1.0, 0.85, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.60, 2.0, 2.0, 3.0, 3.0], + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + ] + ], + dtype="float32", + ) + + tvm.testing.assert_allclose(result.numpy(), ref_out_data) + tvm.testing.assert_allclose(result.numpy(), expected_out_data) + + +@tvm.testing.requires_llvm +def test_nms_e2e_index_remap(): + """Validate that returned indices remap from filtered order back to original order.""" + + raw_data = np.array( + [ + [ + [-1.0, 0.99, 9.0, 9.0, 10.0, 10.0], + [0.0, 0.60, 4.0, 4.0, 5.0, 5.0], + [0.0, 0.10, 8.0, 8.0, 9.0, 9.0], + [0.0, 0.95, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.90, 0.05, 0.05, 1.05, 1.05], + [1.0, 0.80, 2.0, 2.0, 3.0, 3.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + return_indices=True, + invalid_to_bottom=False, + ) + + tvm.testing.assert_allclose(result[0].numpy(), ref_indices) + tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) + np.testing.assert_array_equal(ref_indices, np.array([[3, 5, 1, -1, -1, -1]], dtype="int32")) + np.testing.assert_array_equal(ref_valid_box_count, np.array([[3]], dtype="int32")) + + def test_all_class_non_max_suppression_infer_struct_info(): bb = relax.BlockBuilder() batch_size, num_classes, num_boxes = 10, 8, 5 @@ -450,11 +1302,11 @@ def _softmax(x, axis): boxes = np.zeros((B, N, 4), dtype=np.float32) for b in range(B): for a in range(N): - l, t, r, br = anchor[0, a, :] - ay = (t + br) * 0.5 - ax = (l + r) * 0.5 - ah = br - t 
- aw = r - l + left, top, right, bottom = anchor[0, a, :] + ay = (top + bottom) * 0.5 + ax = (left + right) * 0.5 + ah = bottom - top + aw = right - left ex, ey, ew, eh = loc[b, a, :] ycenter = ey * vy * ah + ay xcenter = ex * vx * aw + ax diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py index f053e3674493..370b68769e6e 100644 --- a/tests/python/relax/test_tvmscript_parser_op_vision.py +++ b/tests/python/relax/test_tvmscript_parser_op_vision.py @@ -75,6 +75,138 @@ def foo( _check(foo, bb.get()["foo"]) +def test_get_valid_counts(): + @R.function + def foo( + data: R.Tensor((10, 5, 6), "float32"), + ) -> R.Tuple( + R.Tensor((10,), "int32"), + R.Tensor((10, 5, 6), "float32"), + R.Tensor((10, 5), "int32"), + ): + gv: R.Tuple( + R.Tensor((10,), "int32"), + R.Tensor((10, 5, 6), "float32"), + R.Tensor((10, 5), "int32"), + ) = R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1) + return gv + + data = relax.Var("data", R.Tensor((10, 5, 6), "float32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [data]): + gv = bb.emit( + relax.op.vision.get_valid_counts( + data, score_threshold=0.5, id_index=0, score_index=1 + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_non_max_suppression_return_indices(): + @R.function + def foo( + data: R.Tensor((2, 5, 6), "float32"), + valid_count: R.Tensor((2,), "int32"), + indices: R.Tensor((2, 5), "int32"), + ) -> R.Tuple(R.Tensor((2, 5), "int32"), R.Tensor((2, 1), "int32")): + gv: R.Tuple(R.Tensor((2, 5), "int32"), R.Tensor((2, 1), "int32")) = ( + R.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=3, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + ) + return gv + + data = relax.Var("data", R.Tensor((2, 5, 6), "float32")) + valid_count = relax.Var("valid_count", 
R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 5), "int32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [data, valid_count, indices]): + gv = bb.emit( + relax.op.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=3, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_non_max_suppression_return_data(): + @R.function + def foo( + data: R.Tensor((2, 5, 6), "float32"), + valid_count: R.Tensor((2,), "int32"), + indices: R.Tensor((2, 5), "int32"), + ) -> R.Tensor((2, 5, 6), "float32"): + gv: R.Tensor((2, 5, 6), "float32") = R.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=False, + invalid_to_bottom=True, + ) + return gv + + data = relax.Var("data", R.Tensor((2, 5, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 5), "int32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [data, valid_count, indices]): + gv = bb.emit( + relax.op.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=False, + invalid_to_bottom=True, + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_multibox_transform_loc(): @R.function def foo(