diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 69ce458e7e35..8971127d76dc 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -73,6 +73,23 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter { TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode); }; // struct ROIAlignAttrs +/*! \brief Attributes used in ROIPool operator */ +struct ROIPoolAttrs : public AttrsNodeReflAdapter { + ffi::Array pooled_size; + double spatial_scale; + ffi::String layout; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pooled_size", &ROIPoolAttrs::pooled_size, "Output size of roi pool.") + .def_ro("spatial_scale", &ROIPoolAttrs::spatial_scale, + "Ratio of input feature map height (or width) to raw image height (or width).") + .def_ro("layout", &ROIPoolAttrs::layout, "Dimension ordering of the input data."); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs, BaseAttrsNode); +}; // struct ROIPoolAttrs + /*! \brief Attributes used in GetValidCounts operator */ struct GetValidCountsAttrs : public AttrsNodeReflAdapter { double score_threshold; @@ -132,7 +149,6 @@ struct NonMaximumSuppressionAttrs NonMaximumSuppressionAttrs, BaseAttrsNode); }; // struct NonMaximumSuppressionAttrs - /*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box decode). */ struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter { bool clip; diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index e56f975c6289..6f263b0c1742 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2517,6 +2517,28 @@ def _impl_v16(cls, bb, inputs, attr, params): return cls._impl(bb, inputs, attr, params, b"half_pixel") +class MaxRoiPool(OnnxOpConverter): + """Converts an onnx MaxRoiPool node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr, params): + if len(inputs) != 2: + raise ValueError("MaxRoiPool expects exactly 2 inputs") + + pooled_shape = attr.get("pooled_shape") + if pooled_shape is None: + raise ValueError("MaxRoiPool requires pooled_shape attribute") + + spatial_scale = attr.get("spatial_scale", 1.0) + return relax.op.vision.roi_pool( + inputs[0], + inputs[1], + pooled_size=tuple(pooled_shape), + spatial_scale=spatial_scale, + layout="NCHW", + ) + + class Range(OnnxOpConverter): """Converts an onnx Range node into an equivalent Relax expression.""" @@ -4177,7 +4199,7 @@ def _get_convert_map(): "OneHot": OneHot, "Unique": Unique, "NonZero": NonZero, - # "MaxRoiPool": MaxRoiPool, + "MaxRoiPool": MaxRoiPool, "RoiAlign": RoiAlign, "NonMaxSuppression": NonMaxSuppression, "AllClassNMS": AllClassNMS, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 0b8dc4e7de59..6f985ef36cac 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -163,6 +163,7 @@ multibox_transform_loc, non_max_suppression, roi_align, + roi_pool, ) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index d85c439d3ae2..1a186a79e87b 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -261,6 +261,11 @@ class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" +@tvm_ffi.register_object("relax.attrs.ROIPoolAttrs") +class ROIPoolAttrs(Attrs): + """Attributes for vision.roi_pool""" + + @tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs") class MultiboxTransformLocAttrs(Attrs): """Attributes for vision.multibox_transform_loc""" diff --git a/python/tvm/relax/op/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py index 58266c5b2add..f99bbc95dd55 100644 --- a/python/tvm/relax/op/vision/__init__.py +++ b/python/tvm/relax/op/vision/__init__.py @@ -20,3 +20,4 @@ from .multibox_transform_loc import * from .nms import * from .roi_align import * +from .roi_pool import * diff --git a/python/tvm/relax/op/vision/roi_pool.py b/python/tvm/relax/op/vision/roi_pool.py new file mode 100644 index 000000000000..f8b7f114635a --- /dev/null +++ b/python/tvm/relax/op/vision/roi_pool.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""ROI Pool operator""" + +from ..base import Expr +from . import _ffi_api + + +def roi_pool( + data: Expr, + rois: Expr, + pooled_size: int | tuple[int, int] | list[int], + spatial_scale: float, + layout: str = "NCHW", +): + """ROI Pool operator. + + Parameters + ---------- + data : relax.Expr + 4-D input tensor. + + rois : relax.Expr + 2-D input tensor with shape `(num_roi, 5)` in + `[batch_idx, x1, y1, x2, y2]` format. + + pooled_size : Union[int, Tuple[int, int], List[int]] + Output pooled size. + + spatial_scale : float + Ratio of input feature map height (or width) to raw image height (or width). + + layout : str, optional + Layout of the input data. Currently only `NCHW` is supported. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(pooled_size, int): + pooled_size = (pooled_size, pooled_size) + return _ffi_api.roi_pool(data, rois, pooled_size, spatial_scale, layout) diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index ea0458bfcef0..7d8586ab5288 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -150,6 +150,18 @@ def _non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.vision.roi_pool") +def _roi_pool(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.vision.roi_pool, + call.args[0], + call.args[1], + pooled_size=call.attrs.pooled_size, + spatial_scale=call.attrs.spatial_scale, + layout=call.attrs.layout, + ) + + @register_legalize("relax.vision.multibox_transform_loc") def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr: variances = tuple(float(x) for x in call.attrs.variances) diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index b0ac67176328..f6591b28717e 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -146,10 +146,17 @@ def method(*args, **kwargs): fields = metadata.get("fields", []) methods = metadata.get("methods", []) - class TVMDerivedObject(metadata["cls"]): # type: ignore + base_cls = metadata["cls"] + slots = [] + if getattr(base_cls, "__dictoffset__", 0) == 0: + slots.append("__dict__") + if getattr(base_cls, "__weakrefoffset__", 0) == 0: + slots.append("__weakref__") + + class TVMDerivedObject(base_cls): # type: ignore """The derived object to avoid cyclic dependency.""" - __slots__ = ("__dict__", "__weakref__",) + __slots__ = tuple(slots) _cls = cls _type = "TVMDerivedObject" diff --git a/python/tvm/s_tir/meta_schedule/utils.py b/python/tvm/s_tir/meta_schedule/utils.py index 2460a6cc265d..13442117117f 100644 --- a/python/tvm/s_tir/meta_schedule/utils.py +++ b/python/tvm/s_tir/meta_schedule/utils.py @@ -106,10 +106,17 @@ def method(*args, **kwargs): fields = metadata.get("fields", []) methods = metadata.get("methods", []) - class TVMDerivedObject(metadata["cls"]): # type: ignore + base_cls = metadata["cls"] + slots = [] + if getattr(base_cls, "__dictoffset__", 0) == 0: + slots.append("__dict__") + if getattr(base_cls, "__weakrefoffset__", 0) == 0: + slots.append("__weakref__") + + class TVMDerivedObject(base_cls): # type: ignore """The derived object to avoid cyclic dependency.""" - __slots__ = ("__dict__", "__weakref__",) + __slots__ = tuple(slots) _cls = cls _type = "TVMDerivedObject" diff --git a/python/tvm/topi/vision/__init__.py b/python/tvm/topi/vision/__init__.py index cb0467c98cd4..93074201f580 100644 --- a/python/tvm/topi/vision/__init__.py +++ b/python/tvm/topi/vision/__init__.py @@ -20,3 +20,4 @@ from .multibox_transform_loc import * from .nms import * from .roi_align import * +from .roi_pool import * diff --git a/python/tvm/topi/vision/roi_pool.py b/python/tvm/topi/vision/roi_pool.py new file mode 100644 index 000000000000..54a4aeba50be --- /dev/null +++ b/python/tvm/topi/vision/roi_pool.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""ROI Pool operator""" + +import tvm +from tvm import te + + +def roi_pool_nchw(data, rois, pooled_size, spatial_scale): + """ROI pool operator in NCHW layout.""" + _, channel, height, width = data.shape + num_roi, _ = rois.shape + + if isinstance(pooled_size, int): + pooled_size_h = pooled_size_w = pooled_size + else: + pooled_size_h, pooled_size_w = pooled_size + + zero = tvm.tirx.const(0.0, data.dtype) + roi_dtype = rois.dtype + + neg_inf = tvm.tirx.const(float("-inf"), data.dtype) + + def _bin_bounds(i, ph, pw): + roi = rois[i] + roi_start_w = te.round(roi[1] * spatial_scale).astype("int32") + roi_start_h = te.round(roi[2] * spatial_scale).astype("int32") + roi_end_w = te.round(roi[3] * spatial_scale).astype("int32") + roi_end_h = te.round(roi[4] * spatial_scale).astype("int32") + + roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32")) + roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32")) + + bin_h = tvm.tirx.Cast(roi_dtype, roi_h) / tvm.tirx.const(float(pooled_size_h), roi_dtype) + bin_w = tvm.tirx.Cast(roi_dtype, roi_w) / tvm.tirx.const(float(pooled_size_w), roi_dtype) + + hstart = te.floor(tvm.tirx.Cast(roi_dtype, ph) * bin_h).astype("int32") + wstart = te.floor(tvm.tirx.Cast(roi_dtype, pw) * bin_w).astype("int32") + hend = te.ceil(tvm.tirx.Cast(roi_dtype, ph + 1) * bin_h).astype("int32") + wend = te.ceil(tvm.tirx.Cast(roi_dtype, pw + 1) * bin_w).astype("int32") + + hstart = te.min(te.max(hstart + roi_start_h, 0), height) + hend = te.min(te.max(hend + roi_start_h, 0), height) + wstart = te.min(te.max(wstart + roi_start_w, 0), width) + wend = te.min(te.max(wend + roi_start_w, 0), width) + return hstart, hend, wstart, wend + + def _sample(i, c, ph, pw): + roi = rois[i] + batch_index = roi[0].astype("int32") + hstart, hend, wstart, wend = _bin_bounds(i, ph, pw) + valid = tvm.tirx.all(hstart <= rh, rh < hend, wstart <= rw, rw < wend) + return tvm.tirx.if_then_else(valid, data[batch_index, c, rh, rw], neg_inf) + + def _is_empty(i, ph, pw): + hstart, hend, wstart, wend = _bin_bounds(i, ph, pw) + return tvm.tirx.any(hend <= hstart, wend <= wstart) + + rh = te.reduce_axis((0, height), name="rh") + rw = te.reduce_axis((0, width), name="rw") + pooled = te.compute( + (num_roi, channel, pooled_size_h, pooled_size_w), + lambda i, c, ph, pw: te.max(_sample(i, c, ph, pw), axis=[rh, rw]), + tag="pool,roi_pool_nchw", + ) + + return te.compute( + (num_roi, channel, pooled_size_h, pooled_size_w), + lambda i, c, ph, pw: tvm.tirx.if_then_else( + _is_empty(i, ph, pw), zero, pooled[i, c, ph, pw] + ), + ) + + +def roi_pool(data, rois, pooled_size, spatial_scale, layout="NCHW"): + """ROI pool operator.""" + if layout == "NCHW": + return roi_pool_nchw(data, rois, pooled_size, spatial_scale) + raise ValueError(f"Unsupported layout for roi_pool: {layout}") diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc new file mode 100644 index 000000000000..93eddb04cb8e --- /dev/null +++ b/src/relax/op/vision/roi_pool.cc @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file roi_pool.cc + * \brief ROI Pool operators. + */ + +#include "roi_pool.h" + +#include + +#include + +namespace tvm { +namespace relax { + +TVM_FFI_STATIC_INIT_BLOCK() { ROIPoolAttrs::RegisterReflection(); } + +Expr roi_pool(Expr data, Expr rois, ffi::Array pooled_size, double spatial_scale, + ffi::String layout) { + if (pooled_size.size() == 1) { + pooled_size.push_back(pooled_size[0]); + } + TVM_FFI_ICHECK_EQ(pooled_size.size(), 2) + << "The input pooled_size length is expected to be 2. However, the given pooled_size is " + << pooled_size; + + auto attrs = ffi::make_object(); + attrs->pooled_size = std::move(pooled_size); + attrs->spatial_scale = spatial_scale; + attrs->layout = layout; + + static const Op& op = Op::Get("relax.vision.roi_pool"); + return Call(op, {std::move(data), std::move(rois)}, Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vision.roi_pool", roi_pool); +} + +StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIPool expects two arguments, while the given number of arguments is " + << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* rois_sinfo = GetStructInfoAs(call->args[1]); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIPool expects the input data to be a Tensor, while the given data is " + << call->args[0]->GetTypeKey()); + } + if (rois_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIPool expects the rois to be a Tensor, while the given rois is " + << call->args[1]->GetTypeKey()); + } + if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIPool expects the input data to be 4-D, while the given data has ndim " + << data_sinfo->ndim); + } + if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIPool expects the rois tensor to be 2-D, while the given rois has ndim " + << rois_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid ROIPool attrs"; + if (attrs->layout != "NCHW") { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIPool only supports NCHW layout, but got " << attrs->layout); + } + + const auto* rois_shape = rois_sinfo->shape.as(); + if (rois_shape != nullptr) { + const auto* last_dim = rois_shape->values[1].as(); + if (last_dim != nullptr && last_dim->value != 5) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ROIPool expects rois to have shape (num_roi, 5), but got last " + "dimension " + << last_dim->value); + } + } + + if (data_sinfo->shape.as() == nullptr || rois_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice); + } + + ffi::Array data_shape = data_sinfo->shape.as()->values; + ffi::Array out_shape = {rois_shape->values[0], data_shape[1], + Integer(attrs->pooled_size[0]), Integer(attrs->pooled_size[1])}; + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.vision.roi_pool") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", + "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.") + .set_attr("FInferStructInfo", InferStructInfoROIPool) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/vision/roi_pool.h b/src/relax/op/vision/roi_pool.h new file mode 100644 index 000000000000..738dbee0d836 --- /dev/null +++ b/src/relax/op/vision/roi_pool.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file roi_pool.h + * \brief The functions to make Relax ROI Pool operator calls. + */ + +#ifndef TVM_RELAX_OP_VISION_ROI_POOL_H_ +#define TVM_RELAX_OP_VISION_ROI_POOL_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief ROI Pool operator. */ +Expr roi_pool(Expr data, Expr rois, ffi::Array pooled_size, double spatial_scale, + ffi::String layout); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_VISION_ROI_POOL_H_ diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index e2067bad2379..6690bf23f6bc 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -4416,5 +4416,51 @@ def test_if_nested(): ) +@pytest.mark.parametrize( + ("pooled_shape", "rois"), + [ + ((1, 1), np.array([[0.0, 1.0, 1.0, 6.0, 6.0], [0.0, 0.0, 0.0, 7.0, 7.0]], dtype="float32")), + ( + (2, 3), + np.array([[0.0, 1.2, 0.5, 6.8, 7.0], [0.0, -1.0, 2.0, 3.5, 5.2]], dtype="float32"), + ), + ( + (2, 2), + np.array( + [[0.0, 100.0, 100.0, 110.0, 110.0], [0.0, 1.0, 1.0, 6.0, 6.0]], dtype="float32" + ), + ), + ], +) +def test_max_roi_pool(pooled_shape, rois): + x_shape = [1, 4, 8, 8] + out_shape = [2, 4, pooled_shape[0], pooled_shape[1]] + + node = helper.make_node( + "MaxRoiPool", + inputs=["X", "rois"], + outputs=["Y"], + pooled_shape=pooled_shape, + spatial_scale=1.0, + ) + + graph = helper.make_graph( + [node], + "max_roi_pool_test", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape), + helper.make_tensor_value_info("rois", TensorProto.FLOAT, [2, 5]), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape)], + ) + + model = helper.make_model(graph, producer_name="max_roi_pool_test") + inputs = { + "X": rg.standard_normal(size=x_shape).astype("float32"), + "rois": rois, + } + check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 6d04a796ca9b..b597b325f4fe 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -1050,6 +1050,96 @@ def test_nms_e2e_index_remap(): np.testing.assert_array_equal(ref_valid_box_count, np.array([[3]], dtype="int32")) +def test_roi_pool_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((4, 5), "float32")) + assert relax.op.vision.roi_pool(x, rois, (7, 7), 1.0).op == Op.get("relax.vision.roi_pool") + + +def test_roi_pool_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((5, 5), "float32")) + + _check_inference( + bb, + relax.op.vision.roi_pool(x, rois, (7, 5), 0.25), + relax.TensorStructInfo((5, 3, 7, 5), "float32"), + ) + + +def test_roi_pool_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + h = tirx.Var("h", "int64") + w = tirx.Var("w", "int64") + num_roi = tirx.Var("num_roi", "int64") + + x = relax.Var("x", R.Tensor((n, c, h, w), "float32")) + rois = relax.Var("rois", R.Tensor((num_roi, 5), "float32")) + + _check_inference( + bb, + relax.op.vision.roi_pool(x, rois, (7, 7), 0.5), + relax.TensorStructInfo((num_roi, c, 7, 7), "float32"), + ) + + +def test_roi_pool_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois0 = relax.Var("rois", R.Tensor((4,), "float32")) + rois1 = relax.Var("rois", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_pool(x0, rois1, (7, 7), 1.0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_pool(x1, rois0, (7, 7), 1.0)) + + +def test_roi_pool_wrong_rois_last_dim(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((4, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_pool(x, rois, (7, 7), 1.0)) + + +def test_roi_pool_wrong_layout(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + rois = relax.Var("rois", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.roi_pool(x, rois, (7, 7), 1.0, layout="NHWC")) + + +def test_roi_pool_legalize(): + @tvm.script.ir_module + class ROIPool: + @R.function + def main( + x: R.Tensor((1, 2, 8, 8), "float32"), + rois: R.Tensor((2, 5), "float32"), + ) -> R.Tensor((2, 2, 3, 2), "float32"): + gv: R.Tensor((2, 2, 3, 2), "float32") = R.vision.roi_pool( + x, + rois, + pooled_size=(3, 2), + spatial_scale=1.0, + layout="NCHW", + ) + return gv + + mod = LegalizeOps()(ROIPool) + assert "call_tir" in str(mod) + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TensorStructInfo((2, 2, 3, 2), "float32"), + ) def test_all_class_non_max_suppression_infer_struct_info(): bb = relax.BlockBuilder() batch_size, num_classes, num_boxes = 10, 8, 5 @@ -1201,12 +1291,9 @@ def test_multibox_transform_loc_op_correctness(): cls = relax.Var("cls", R.Tensor((1, 5, 10), "float32")) loc = relax.Var("loc", R.Tensor((1, 40), "float32")) anc = relax.Var("anc", R.Tensor((1, 10, 4), "float32")) - assert ( - relax.op.vision.multibox_transform_loc( - cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True - ).op - == Op.get("relax.vision.multibox_transform_loc") - ) + assert relax.op.vision.multibox_transform_loc( + cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True + ).op == Op.get("relax.vision.multibox_transform_loc") def test_multibox_transform_loc_infer_struct_info():