Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions include/tvm/relax/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,66 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode);
}; // struct ROIAlignAttrs

/*!
 * \brief Attributes used in GetValidCounts operator.
 *
 * The object is registered under the type key "relax.attrs.GetValidCountsAttrs" and
 * every field below is exposed read-only through the FFI reflection registry.
 */
struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
  /*! \brief Lower limit of score for valid bounding boxes. */
  double score_threshold;
  /*! \brief Index of the class categories, -1 to disable. */
  int id_index;
  /*! \brief Index of the scores/confidence of boxes. */
  int score_index;

  /*! \brief Register each field (read-only) together with its help string. */
  static void RegisterReflection() {
    // Local alias keeps the member-pointer expressions short; the registration
    // order matches the declaration order of the fields.
    using Self = GetValidCountsAttrs;
    tvm::ffi::reflection::ObjectDef<Self>()
        .def_ro("score_threshold", &Self::score_threshold,
                "Lower limit of score for valid bounding boxes.")
        .def_ro("id_index", &Self::id_index, "Index of the class categories, -1 to disable.")
        .def_ro("score_index", &Self::score_index, "Index of the scores/confidence of boxes.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GetValidCountsAttrs", GetValidCountsAttrs,
                                    BaseAttrsNode);
};  // struct GetValidCountsAttrs

/*!
 * \brief Attributes used in NonMaximumSuppression operator.
 *
 * The object is registered under the type key "relax.attrs.NonMaximumSuppressionAttrs"
 * and every field below is exposed read-only through the FFI reflection registry.
 */
struct NonMaximumSuppressionAttrs : public AttrsNodeReflAdapter<NonMaximumSuppressionAttrs> {
  /*! \brief Max number of output valid boxes, -1 for no limit. */
  int max_output_size;
  /*! \brief Non-maximum suppression IoU threshold. */
  double iou_threshold;
  /*! \brief Whether to suppress all detections regardless of class_id. */
  bool force_suppress;
  /*! \brief Keep maximum top k detections before nms, -1 for no limit. */
  int top_k;
  /*! \brief Start index of the consecutive 4 coordinates. */
  int coord_start;
  /*! \brief Index of the scores/confidence of boxes. */
  int score_index;
  /*! \brief Index of the class categories, -1 to disable. */
  int id_index;
  /*! \brief Whether to return box indices in input data. */
  bool return_indices;
  /*! \brief Placement flag for invalid boxes; see the registered help string below. */
  bool invalid_to_bottom;

  /*! \brief Register each field (read-only) together with its help string. */
  static void RegisterReflection() {
    // Local alias keeps the member-pointer expressions short; the registration
    // order matches the declaration order of the fields.
    using Self = NonMaximumSuppressionAttrs;
    tvm::ffi::reflection::ObjectDef<Self>()
        .def_ro("max_output_size", &Self::max_output_size,
                "Max number of output valid boxes, -1 for no limit.")
        .def_ro("iou_threshold", &Self::iou_threshold, "Non-maximum suppression IoU threshold.")
        .def_ro("force_suppress", &Self::force_suppress,
                "Whether to suppress all detections regardless of class_id.")
        .def_ro("top_k", &Self::top_k,
                "Keep maximum top k detections before nms, -1 for no limit.")
        .def_ro("coord_start", &Self::coord_start,
                "Start index of the consecutive 4 coordinates.")
        .def_ro("score_index", &Self::score_index, "Index of the scores/confidence of boxes.")
        .def_ro("id_index", &Self::id_index, "Index of the class categories, -1 to disable.")
        .def_ro("return_indices", &Self::return_indices,
                "Whether to return box indices in input data.")
        .def_ro("invalid_to_bottom", &Self::invalid_to_bottom,
                "Whether to move all valid bounding boxes to the top.");
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
                                    NonMaximumSuppressionAttrs, BaseAttrsNode);
};  // struct NonMaximumSuppressionAttrs


/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box decode). */
struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
bool clip;
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,13 @@
tanh,
trunc,
)
from .vision import all_class_non_max_suppression, multibox_transform_loc, roi_align
from .vision import (
all_class_non_max_suppression,
get_valid_counts,
multibox_transform_loc,
non_max_suppression,
roi_align,
)


def _register_op_make():
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ class AllClassNonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.all_class_non_max_suppression"""


@tvm_ffi.register_object("relax.attrs.GetValidCountsAttrs")
class GetValidCountsAttrs(Attrs):
    """Attributes for vision.get_valid_counts.

    Registered under the FFI type key ``relax.attrs.GetValidCountsAttrs``. The class
    declares no Python-side members of its own; the C++ struct registered under the
    same key exposes ``score_threshold``, ``id_index`` and ``score_index`` as
    read-only fields.
    """


@tvm_ffi.register_object("relax.attrs.NonMaximumSuppressionAttrs")
class NonMaximumSuppressionAttrs(Attrs):
    """Attributes for vision.non_max_suppression.

    Registered under the FFI type key ``relax.attrs.NonMaximumSuppressionAttrs``. The
    class declares no Python-side members of its own; the C++ struct registered under
    the same key exposes ``max_output_size``, ``iou_threshold``, ``force_suppress``,
    ``top_k``, ``coord_start``, ``score_index``, ``id_index``, ``return_indices`` and
    ``invalid_to_bottom`` as read-only fields.
    """


@tvm_ffi.register_object("relax.attrs.ROIAlignAttrs")
class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""
Expand Down
114 changes: 112 additions & 2 deletions python/tvm/relax/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Non-maximum suppression operator"""
"""Non-maximum suppression operators."""

# from tvm import relax # Unused import
from . import _ffi_api


Expand Down Expand Up @@ -72,3 +71,114 @@ def all_class_non_max_suppression(
return _ffi_api.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
)


def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
    """Count the bounding boxes that pass a score threshold.

    Besides counting, the operator also moves the valid boxes to the top of the
    input data.

    Parameters
    ----------
    data : relax.Expr
        3-D tensor of shape ``[batch_size, num_anchors, elem_length]``.

    score_threshold : float, optional
        Lower limit of score for valid bounding boxes.

    id_index : int, optional
        Position of the class-category entry inside each box. Pass ``-1`` to
        disable the class-id check.

    score_index : int, optional
        Position of the score/confidence entry inside each box.

    Returns
    -------
    out : relax.Expr
        A tuple ``(valid_count, out_tensor, out_indices)``:

        - ``valid_count`` has shape ``[batch_size]``,
        - ``out_tensor`` has shape ``[batch_size, num_anchors, elem_length]``,
        - ``out_indices`` has shape ``[batch_size, num_anchors]``.
    """
    # Arguments are forwarded positionally, in the order of this signature.
    ffi_args = (data, score_threshold, id_index, score_index)
    return _ffi_api.get_valid_counts(*ffi_args)


def non_max_suppression(
    data,
    valid_count,
    indices,
    max_output_size=-1,
    iou_threshold=0.5,
    force_suppress=False,
    top_k=-1,
    coord_start=2,
    score_index=1,
    id_index=0,
    return_indices=True,
    invalid_to_bottom=False,
):
    """Non-maximum suppression operator for object detection.

    Parameters
    ----------
    data : relax.Expr
        3-D tensor of shape ``[batch_size, num_anchors, elem_length]``.

    valid_count : relax.Expr
        1-D tensor holding the number of valid boxes.

    indices : relax.Expr
        2-D tensor of shape ``[batch_size, num_anchors]``.

    max_output_size : int, optional
        Upper bound on the number of valid output boxes; ``-1`` means no limit.

    iou_threshold : float, optional
        IoU threshold used by the suppression.

    force_suppress : bool, optional
        Suppress all detections regardless of class_id. With ``id_index`` set to
        ``-1`` every valid box is treated as belonging to one class, so the
        result is the same as passing ``True`` here.

    top_k : int, optional
        Keep at most this many top detections before nms; ``-1`` means no limit.

    coord_start : int, optional
        Index at which the 4 consecutive box coordinates start.

    score_index : int, optional
        Position of the score/confidence entry inside each box.

    id_index : int, optional
        Position of the class-category entry inside each box. Pass ``-1`` to
        suppress boxes across all classes.

    return_indices : bool, optional
        Whether to return box indices in input data.

    invalid_to_bottom : bool, optional
        Whether to move valid bounding boxes to the top of the returned tensor.
        Only the ``return_indices=False`` path is affected by this option.

    Returns
    -------
    out : relax.Expr
        When ``return_indices`` is ``True``: ``(box_indices, valid_box_count)``
        with shapes ``[batch_size, num_anchors]`` and ``[batch_size, 1]``.
        Otherwise: the modified data tensor.
    """
    # Tensor operands first, then the attributes, in the order of this signature.
    tensor_args = (data, valid_count, indices)
    attr_args = (
        max_output_size,
        iou_threshold,
        force_suppress,
        top_k,
        coord_start,
        score_index,
        id_index,
        return_indices,
        invalid_to_bottom,
    )
    return _ffi_api.non_max_suppression(*tensor_args, *attr_args)
30 changes: 30 additions & 0 deletions python/tvm/relax/transform/legalize_ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,36 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.vision.get_valid_counts")
def _get_valid_counts(block_builder: BlockBuilder, call: Call) -> Expr:
    """Legalize ``relax.vision.get_valid_counts`` via ``topi.vision.get_valid_counts``.

    The single operand of ``call`` is passed as the TE input, and the three
    attributes on ``call.attrs`` are forwarded as keyword arguments of the same
    name.
    """
    attrs = call.attrs
    data = call.args[0]
    topi_kwargs = {
        "score_threshold": attrs.score_threshold,
        "id_index": attrs.id_index,
        "score_index": attrs.score_index,
    }
    return block_builder.call_te(topi.vision.get_valid_counts, data, **topi_kwargs)


@register_legalize("relax.vision.non_max_suppression")
def _non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
    """Legalize ``relax.vision.non_max_suppression`` via ``topi.vision.non_max_suppression``.

    The first three operands of ``call`` are passed as the TE inputs, and every
    attribute on ``call.attrs`` is forwarded as a keyword argument of the same
    name.
    """
    attrs = call.attrs
    data = call.args[0]
    valid_count = call.args[1]
    indices = call.args[2]
    topi_kwargs = {
        "max_output_size": attrs.max_output_size,
        "iou_threshold": attrs.iou_threshold,
        "force_suppress": attrs.force_suppress,
        "top_k": attrs.top_k,
        "coord_start": attrs.coord_start,
        "score_index": attrs.score_index,
        "id_index": attrs.id_index,
        "return_indices": attrs.return_indices,
        "invalid_to_bottom": attrs.invalid_to_bottom,
    }
    return block_builder.call_te(
        topi.vision.non_max_suppression, data, valid_count, indices, **topi_kwargs
    )


@register_legalize("relax.vision.multibox_transform_loc")
def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
variances = tuple(float(x) for x in call.attrs.variances)
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
from .gather_nd_python import gather_nd_python
from .get_valid_counts_python import get_valid_counts_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .batch_norm import batch_norm
from .nms_python import non_max_suppression_python
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .poolnd_python import poolnd_python
Expand Down
68 changes: 68 additions & 0 deletions python/tvm/topi/testing/get_valid_counts_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Numpy reference implementation for get_valid_counts."""
import numpy as np


def get_valid_counts_python(data, score_threshold=0, id_index=0, score_index=1):
    """Numpy reference for get_valid_counts.

    A box is valid when its score is strictly greater than ``score_threshold``
    and, if ``id_index`` is non-negative, its class id is ``>= 0``. Valid boxes
    are packed at the front of each batch in their original order; the remaining
    slots of both outputs keep the fill value ``-1``.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array with shape [batch_size, num_anchors, elem_length].

    score_threshold : float
        Lower limit of score for valid bounding boxes.

    id_index : int
        Index of the class categories, -1 to disable.

    score_index : int
        Index of the scores/confidence of boxes.

    Returns
    -------
    valid_count : numpy.ndarray
        1-D int32 array, shape [batch_size].

    out_tensor : numpy.ndarray
        Rearranged data, shape [batch_size, num_anchors, elem_length].

    out_indices : numpy.ndarray
        int32 indices mapping, shape [batch_size, num_anchors].
    """
    batch_size, num_anchors, _ = data.shape
    check_class_id = id_index >= 0

    # Every output starts in its "nothing valid" state and is filled from the front.
    out_tensor = np.full_like(data, -1.0)
    out_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
    valid_count = np.zeros(batch_size, dtype="int32")

    for batch, boxes in enumerate(data):
        kept = 0
        for anchor, box in enumerate(boxes):
            # The score is tested first; the class id is only read for boxes
            # that already passed the score test.
            if not (box[score_index] > score_threshold):
                continue
            if check_class_id and not (box[id_index] >= 0):
                continue
            out_tensor[batch, kept] = box
            out_indices[batch, kept] = anchor
            kept += 1
        valid_count[batch] = kept

    return valid_count, out_tensor, out_indices
Loading
Loading