From e8683a3cb18c0910dbb4684122c5670d08349c37 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Fri, 27 Mar 2026 04:18:04 +0000 Subject: [PATCH 1/4] [Relax][TOPI] Add relax.vision.multibox_transform_loc for SSD/TFLite box decode Introduce relax.vision.multibox_transform_loc with MultiboxTransformLocAttrs: decode center-size offsets against ltrb priors, softmax on class logits, and optional clip, threshold masking, and background score zeroing. Register the C++ op with FInferStructInfo checks for shapes and dtypes (including batch and 4*N consistency). Legalize to topi.vision.multibox_transform_loc. Add tests for struct inference, invalid inputs, Legalize+e2e on LLVM, attribute branches, and TVMScript roundtrip. Add a standalone numpy reference under topi/testing (not exported from tvm.topi.testing to avoid pulling scipy). Update TFLite frontend NotImplementedError text for DETECTION_POSTPROCESS and NON_MAX_SUPPRESSION_V5 to note multibox is available and link tracking issue #18928. --- include/tvm/relax/attrs/vision.h | 22 ++ .../relax/frontend/tflite/tflite_frontend.py | 12 +- python/tvm/relax/op/__init__.py | 2 +- python/tvm/relax/op/op_attrs.py | 5 + python/tvm/relax/op/vision/__init__.py | 1 + .../relax/op/vision/multibox_transform_loc.py | 77 ++++++ .../relax/transform/legalize_ops/vision.py | 24 ++ .../testing/multibox_transform_loc_python.py | 72 ++++++ python/tvm/topi/vision/__init__.py | 1 + .../tvm/topi/vision/multibox_transform_loc.py | 125 ++++++++++ src/relax/op/vision/multibox_transform_loc.cc | 189 +++++++++++++++ src/relax/op/vision/multibox_transform_loc.h | 42 ++++ tests/python/relax/test_op_vision.py | 228 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_vision.py | 42 ++++ 14 files changed, 835 insertions(+), 7 deletions(-) create mode 100644 python/tvm/relax/op/vision/multibox_transform_loc.py create mode 100644 python/tvm/topi/testing/multibox_transform_loc_python.py create mode 100644 python/tvm/topi/vision/multibox_transform_loc.py create mode 100644 src/relax/op/vision/multibox_transform_loc.cc create mode 100644 src/relax/op/vision/multibox_transform_loc.h diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 59a1dd7314fc..c73ac3b6b556 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -73,6 +73,28 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter { TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode); }; // struct ROIAlignAttrs +/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box decode). */ +struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter { + bool clip; + double threshold; + ffi::Array variances; + bool keep_background; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("clip", &MultiboxTransformLocAttrs::clip, "Clip decoded ymin,xmin,ymax,xmax to [0,1].") + .def_ro("threshold", &MultiboxTransformLocAttrs::threshold, + "After softmax, zero scores strictly below this value.") + .def_ro("variances", &MultiboxTransformLocAttrs::variances, + "(x,y,w,h) scales = TFLite 1/x_scale,1/y_scale,1/w_scale,1/h_scale on encodings.") + .def_ro("keep_background", &MultiboxTransformLocAttrs::keep_background, + "If false, force output scores[:,0,:] to 0 (background class)."); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultiboxTransformLocAttrs", + MultiboxTransformLocAttrs, BaseAttrsNode); +}; // struct MultiboxTransformLocAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 0abd700562e2..60cbc3178aee 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -3205,9 +3205,10 @@ def convert_dequantize(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" raise NotImplementedError( - "DETECTION_POSTPROCESS requires vision ops (multibox_transform_loc, " - "non_max_suppression, get_valid_counts) not yet available in Relax. " - "See https://github.com/apache/tvm/issues/XXXX" + "DETECTION_POSTPROCESS is not wired in this frontend yet: it still needs " + "Relax NMS / get_valid_counts / related vision helpers (see dead code below). " + "relax.vision.multibox_transform_loc exists; tracking: " + "https://github.com/apache/tvm/issues/18928" ) flexbuffer = op.CustomOptionsAsNumpy().tobytes() custom_options = FlexBufferDecoder(flexbuffer).decode() @@ -3340,9 +3341,8 @@ def convert_nms_v5(self, op): """Convert TFLite NonMaxSuppressionV5""" # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5 raise NotImplementedError( - "NON_MAX_SUPPRESSION_V5 requires vision ops (get_valid_counts, " - "non_max_suppression) not yet available in Relax. " - "See https://github.com/apache/tvm/issues/XXXX" + "NON_MAX_SUPPRESSION_V5 is not wired in this frontend yet (needs get_valid_counts, " + "non_max_suppression, etc.). Tracking: https://github.com/apache/tvm/issues/18928" ) input_tensors = self.get_input_tensors(op) diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 0bc3f6578432..ee1a2c24206e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -157,7 +157,7 @@ tanh, trunc, ) -from .vision import all_class_non_max_suppression, roi_align +from .vision import all_class_non_max_suppression, multibox_transform_loc, roi_align def _register_op_make(): diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index a3b6544dcc6e..e8c91f04b459 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -251,6 +251,11 @@ class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" +@tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs") +class MultiboxTransformLocAttrs(Attrs): + """Attributes for vision.multibox_transform_loc""" + + @tvm_ffi.register_object("relax.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" diff --git a/python/tvm/relax/op/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py index 76d9ea35a11c..58266c5b2add 100644 --- a/python/tvm/relax/op/vision/__init__.py +++ b/python/tvm/relax/op/vision/__init__.py @@ -17,5 +17,6 @@ # under the License. """VISION operators.""" +from .multibox_transform_loc import * from .nms import * from .roi_align import * diff --git a/python/tvm/relax/op/vision/multibox_transform_loc.py b/python/tvm/relax/op/vision/multibox_transform_loc.py new file mode 100644 index 000000000000..e190dbcf835d --- /dev/null +++ b/python/tvm/relax/op/vision/multibox_transform_loc.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Multibox location transform for object detection.""" + +from . import _ffi_api + + +def multibox_transform_loc( + cls_pred, + loc_pred, + anchor, + clip=False, + threshold=0.0, + variances=(1.0, 1.0, 1.0, 1.0), + keep_background=True, +): + """SSD / TFLite-style decode: priors + offsets → boxes; logits → softmax scores. + + Box decode follows TFLite ``DecodeCenterSizeBoxes``; expected tensor layout matches + ``tflite_frontend.convert_detection_postprocess`` (loc reorder yxhw→xywh, anchor ltrb). + + Parameters + ---------- + cls_pred : relax.Expr + ``[B, C, N]`` class logits (pre-softmax). + loc_pred : relax.Expr + ``[B, 4*N]`` per-anchor encodings as ``(x,y,w,h)`` after reorder (see above). + anchor : relax.Expr + ``[1, N, 4]`` priors: ``(left, top, right, bottom)``. + clip : bool + If True, clip ``ymin,xmin,ymax,xmax`` to ``[0, 1]``. + threshold : float + After softmax, multiply scores by mask ``(score >= threshold)``. + variances : tuple of 4 floats + ``(x,y,w,h)`` = TFLite ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale``. + keep_background : bool + If False, set output scores at class index 0 to zero. + + Returns + ------- + result : relax.Expr + Tuple ``(boxes, scores)``: ``boxes`` is ``[B, N, 4]`` as ``(ymin,xmin,ymax,xmax)``; + ``scores`` is ``[B, C, N]`` softmax, post-processed like the implementation. + + Notes + ----- + **Shape/dtype (checked in ``FInferStructInfo`` when static):** + + - ``cls_pred``: 3-D; ``loc_pred``: 2-D; ``anchor``: 3-D. + - ``cls_pred``, ``loc_pred``, ``anchor`` dtypes must match. + - ``N = cls_pred.shape[2]``; ``loc_pred.shape[1] == 4*N``; ``anchor.shape == [1,N,4]``. + - ``loc_pred.shape[1]`` must be divisible by 4. + - ``cls_pred.shape[0]`` must equal ``loc_pred.shape[0]`` (batch). + """ + return _ffi_api.multibox_transform_loc( + cls_pred, + loc_pred, + anchor, + clip, + threshold, + variances, + keep_background, + ) diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index 7a1e305f39f0..28367a67a361 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -118,3 +118,27 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr: aligned=call.attrs.aligned, layout=call.attrs.layout, ) + + +@register_legalize("relax.vision.multibox_transform_loc") +def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr: + variances = tuple(float(x) for x in call.attrs.variances) + + def _te(cls_pred, loc_pred, anchor): + return topi.vision.multibox_transform_loc( + cls_pred, + loc_pred, + anchor, + variances, + clip=call.attrs.clip, + threshold=call.attrs.threshold, + keep_background=call.attrs.keep_background, + ) + + return bb.call_te( + _te, + call.args[0], + call.args[1], + call.args[2], + primfunc_name_hint="multibox_transform_loc", + ) diff --git a/python/tvm/topi/testing/multibox_transform_loc_python.py b/python/tvm/topi/testing/multibox_transform_loc_python.py new file mode 100644 index 000000000000..8b9d81505d73 --- /dev/null +++ b/python/tvm/topi/testing/multibox_transform_loc_python.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Numpy reference for multibox_transform_loc.""" + +import numpy as np + + +def _softmax(x, axis): + x_max = np.max(x, axis=axis, keepdims=True) + exp = np.exp(x - x_max) + return exp / np.sum(exp, axis=axis, keepdims=True) + + +def multibox_transform_loc_python( + cls_pred, + loc_pred, + anchor, + variances, + clip=False, + threshold=0.0, + keep_background=True, +): + """Reference implementation aligned with ``topi.vision.multibox_transform_loc``.""" + B, C, N = cls_pred.shape + loc = loc_pred.reshape(B, N, 4) + scores = _softmax(cls_pred.astype("float64"), axis=1).astype(np.float32) + if threshold > 0.0: + scores = np.where(scores >= threshold, scores, 0.0).astype(np.float32) + if not keep_background: + scores = scores.copy() + scores[:, 0, :] = 0.0 + + vx, vy, vw, vh = variances + boxes = np.zeros((B, N, 4), dtype=np.float32) + for b in range(B): + for a in range(N): + l, t, r, br = anchor[0, a, :] + ay = (t + br) * 0.5 + ax = (l + r) * 0.5 + ah = br - t + aw = r - l + ex, ey, ew, eh = loc[b, a, :] + ycenter = ey * vy * ah + ay + xcenter = ex * vx * aw + ax + half_h = 0.5 * np.exp(eh * vh) * ah + half_w = 0.5 * np.exp(ew * vw) * aw + ymin = ycenter - half_h + xmin = xcenter - half_w + ymax = ycenter + half_h + xmax = xcenter + half_w + if clip: + ymin = float(np.clip(ymin, 0.0, 1.0)) + xmin = float(np.clip(xmin, 0.0, 1.0)) + ymax = float(np.clip(ymax, 0.0, 1.0)) + xmax = float(np.clip(xmax, 0.0, 1.0)) + boxes[b, a, :] = (ymin, xmin, ymax, xmax) + return boxes, scores diff --git a/python/tvm/topi/vision/__init__.py b/python/tvm/topi/vision/__init__.py index 75725a8a4bea..cb0467c98cd4 100644 --- a/python/tvm/topi/vision/__init__.py +++ b/python/tvm/topi/vision/__init__.py @@ -17,5 +17,6 @@ # under the License. """Vision operators.""" +from .multibox_transform_loc import * from .nms import * from .roi_align import * diff --git a/python/tvm/topi/vision/multibox_transform_loc.py b/python/tvm/topi/vision/multibox_transform_loc.py new file mode 100644 index 000000000000..43c51392e568 --- /dev/null +++ b/python/tvm/topi/vision/multibox_transform_loc.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Multibox location transform (SSD / TFLite DetectionPostProcess decode).""" + +import tvm +from tvm import te, topi + + +def multibox_transform_loc( + cls_pred, + loc_pred, + anchor, + variances, + clip=False, + threshold=0.0, + keep_background=True, +): + """TFLite ``DecodeCenterSizeBoxes``-style decode + softmax score post-process. + + Inputs must match Relax op contracts: ``cls_pred [B,C,N]``, ``loc_pred [B,4*N]``, + ``anchor [1,N,4]`` ltrb; per-anchor loc order ``(x,y,w,h)`` after yxhw→xywh reorder. + + Parameters + ---------- + cls_pred : te.Tensor + ``[B, C, N]`` logits. + loc_pred : te.Tensor + ``[B, 4*N]`` encodings ``(x,y,w,h)`` per anchor. + anchor : te.Tensor + ``[1, N, 4]`` ``(left, top, right, bottom)``. + variances : tuple of 4 float + ``(x,y,w,h)`` = ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale`` (TFLite). + clip : bool + Clip ``ymin,xmin,ymax,xmax`` to ``[0,1]``. + threshold : float + After softmax: ``scores *= (scores >= threshold)``. + keep_background : bool + If False: ``scores[:,0,:] = 0``. + + Returns + ------- + boxes : te.Tensor + ``[B, N, 4]`` as ``(ymin,xmin,ymax,xmax)``. + scores : te.Tensor + ``[B, C, N]`` softmax, then threshold mask and optional background zero. + """ + dtype = cls_pred.dtype + B = cls_pred.shape[0] + num_anchors = cls_pred.shape[2] + loc_reshaped = topi.reshape(loc_pred, [B, num_anchors, 4]) + + vx = tvm.tirx.const(float(variances[0]), dtype) + vy = tvm.tirx.const(float(variances[1]), dtype) + vw = tvm.tirx.const(float(variances[2]), dtype) + vh = tvm.tirx.const(float(variances[3]), dtype) + half = tvm.tirx.const(0.5, dtype) + zero = tvm.tirx.const(0.0, dtype) + one = tvm.tirx.const(1.0, dtype) + th = tvm.tirx.const(float(threshold), dtype) + + def decode_bbox(b, a, k): + l = anchor[0, a, 0] + t = anchor[0, a, 1] + r = anchor[0, a, 2] + br = anchor[0, a, 3] + ay = (t + br) * half + ax = (l + r) * half + ah = br - t + aw = r - l + ex = loc_reshaped[b, a, 0] + ey = loc_reshaped[b, a, 1] + ew = loc_reshaped[b, a, 2] + eh = loc_reshaped[b, a, 3] + ycenter = ey * vy * ah + ay + xcenter = ex * vx * aw + ax + half_h = half * te.exp(eh * vh) * ah + half_w = half * te.exp(ew * vw) * aw + ymin = ycenter - half_h + xmin = xcenter - half_w + ymax = ycenter + half_h + xmax = xcenter + half_w + if clip: + ymin = te.max(zero, te.min(one, ymin)) + xmin = te.max(zero, te.min(one, xmin)) + ymax = te.max(zero, te.min(one, ymax)) + xmax = te.max(zero, te.min(one, xmax)) + return te.if_then_else( + k == 0, + ymin, + te.if_then_else( + k == 1, + xmin, + te.if_then_else(k == 2, ymax, xmax), + ), + ) + + boxes = te.compute((B, num_anchors, 4), decode_bbox, name="multibox_boxes") + + scores = topi.nn.softmax(cls_pred, axis=1) + mask = topi.cast(topi.greater_equal(scores, th), dtype) + scores = scores * mask + if not keep_background: + + def zero_bg(b, c, n): + s = scores[b, c, n] + return te.if_then_else(c == 0, zero, s) + + scores = te.compute(scores.shape, zero_bg, name="multibox_scores_bg") + + return [boxes, scores] diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc new file mode 100644 index 000000000000..0fd00a27b380 --- /dev/null +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file multibox_transform_loc.cc + * \brief Multibox transform (location decode) for object detection. + */ + +#include "multibox_transform_loc.h" + +#include +#include + +#include + +namespace tvm { +namespace relax { + +TVM_FFI_STATIC_INIT_BLOCK() { MultiboxTransformLocAttrs::RegisterReflection(); } + +Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool clip, double threshold, + ffi::Array variances, bool keep_background) { + TVM_FFI_ICHECK_EQ(variances.size(), 4) + << "multibox_transform_loc: variances must be length 4 (x,y,w,h), got " << variances.size(); + + auto attrs = ffi::make_object(); + attrs->clip = clip; + attrs->threshold = threshold; + attrs->variances = std::move(variances); + attrs->keep_background = keep_background; + + static const Op& op = Op::Get("relax.vision.multibox_transform_loc"); + return Call(op, {std::move(cls_pred), std::move(loc_pred), std::move(anchor)}, Attrs(attrs), {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.vision.multibox_transform_loc", multibox_transform_loc); +} + +StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: expected 3 inputs (cls_pred, loc_pred, anchor), got " + << call->args.size()); + } + + ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto cls_sinfo = input_sinfo[0]; + const auto loc_sinfo = input_sinfo[1]; + const auto anchor_sinfo = input_sinfo[2]; + + if (!cls_sinfo->IsUnknownNdim() && cls_sinfo->ndim != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: cls_pred must be 3-D [B, num_classes, N], got " + "ndim " + << cls_sinfo->ndim); + } + if (!loc_sinfo->IsUnknownNdim() && loc_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: loc_pred must be 2-D [B, 4*N], got ndim " + << loc_sinfo->ndim); + } + if (!anchor_sinfo->IsUnknownNdim() && anchor_sinfo->ndim != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: anchor must be 3-D [1, N, 4] ltrb, got ndim " + << anchor_sinfo->ndim); + } + + if (!cls_sinfo->IsUnknownDtype() && !loc_sinfo->IsUnknownDtype() && cls_sinfo->dtype != loc_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: cls_pred and loc_pred dtype must match, got " + << cls_sinfo->dtype << " vs " << loc_sinfo->dtype); + } + if (!cls_sinfo->IsUnknownDtype() && !anchor_sinfo->IsUnknownDtype() && + cls_sinfo->dtype != anchor_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: cls_pred and anchor dtype must match, got " + << cls_sinfo->dtype << " vs " << anchor_sinfo->dtype); + } + + auto vdev = cls_sinfo->vdevice; + const auto* cls_shape = cls_sinfo->shape.as(); + const auto* loc_shape = loc_sinfo->shape.as(); + const auto* anchor_shape = anchor_sinfo->shape.as(); + + if (loc_shape != nullptr) { + const auto* loc_dim1 = loc_shape->values[1].as(); + if (loc_dim1 != nullptr && loc_dim1->value % 4 != 0) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: loc_pred.shape[1] must be divisible by 4, got " + << loc_dim1->value); + } + } + + if (cls_shape != nullptr && loc_shape != nullptr) { + const auto* cls_b = cls_shape->values[0].as(); + const auto* loc_b = loc_shape->values[0].as(); + if (cls_b != nullptr && loc_b != nullptr && cls_b->value != loc_b->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: cls_pred.shape[0] must match loc_pred.shape[0], " + "got B=" + << cls_b->value << " vs " << loc_b->value); + } + } + + if (anchor_shape != nullptr) { + const auto* anchor_batch = anchor_shape->values[0].as(); + if (anchor_batch != nullptr && anchor_batch->value != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: anchor.shape[0] must be 1, got " + << anchor_batch->value); + } + const auto* anchor_last = anchor_shape->values[2].as(); + if (anchor_last != nullptr && anchor_last->value != 4) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: anchor.shape[2] must be 4 (ltrb), got " + << anchor_last->value); + } + } + + if (cls_shape == nullptr) { + ffi::Array fields = {TensorStructInfo(cls_sinfo->dtype, 3, vdev), + TensorStructInfo(cls_sinfo->dtype, 3, vdev)}; + return TupleStructInfo(fields); + } + + const auto& batch = cls_shape->values[0]; + const auto& num_classes = cls_shape->values[1]; + const auto& num_anchors = cls_shape->values[2]; + + if (loc_shape != nullptr) { + const auto* num_anchors_imm = num_anchors.as(); + const auto* loc_dim1 = loc_shape->values[1].as(); + if (num_anchors_imm != nullptr && loc_dim1 != nullptr && + loc_dim1->value != num_anchors_imm->value * 4) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: loc_pred.shape[1] must equal 4*N with " + "N=cls_pred.shape[2], got loc_dim=" + << loc_dim1->value << ", N=" << num_anchors_imm->value); + } + } + if (anchor_shape != nullptr) { + const auto* num_anchors_imm = num_anchors.as(); + const auto* anchor_num_anchors = anchor_shape->values[1].as(); + if (num_anchors_imm != nullptr && anchor_num_anchors != nullptr && + anchor_num_anchors->value != num_anchors_imm->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "multibox_transform_loc: anchor.shape[1] must equal N=cls_pred.shape[2], " + "got anchor_N=" + << anchor_num_anchors->value << ", N=" << num_anchors_imm->value); + } + } + + ffi::Array boxes_shape = {batch, num_anchors, Integer(4)}; + ffi::Array scores_shape = {batch, num_classes, num_anchors}; + ffi::Array fields = { + TensorStructInfo(ShapeExpr(boxes_shape), cls_sinfo->dtype, vdev), + TensorStructInfo(ShapeExpr(scores_shape), cls_sinfo->dtype, vdev)}; + return TupleStructInfo(fields); +} + +TVM_REGISTER_OP("relax.vision.multibox_transform_loc") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).") + .add_argument("loc_pred", "Tensor", "[B,4*N] box encodings per anchor as (x,y,w,h) after yxhw→xywh.") + .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb (left,top,right,bottom).") + .set_attr("FInferStructInfo", InferStructInfoMultiboxTransformLoc) + .set_attr("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/vision/multibox_transform_loc.h b/src/relax/op/vision/multibox_transform_loc.h new file mode 100644 index 000000000000..726bc4c0e582 --- /dev/null +++ b/src/relax/op/vision/multibox_transform_loc.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file multibox_transform_loc.h + * \brief The functions to make Relax multibox_transform_loc operator calls. + */ + +#ifndef TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_ +#define TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief Decode SSD box encodings and prepare class scores (TFLite-compatible). */ +Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool clip, double threshold, + ffi::Array variances, bool keep_background); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_ diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index b902518b49bb..d39977b8c2e3 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -344,5 +344,233 @@ def main( tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections, 3)) +def test_multibox_transform_loc_op_correctness(): + cls = relax.Var("cls", R.Tensor((1, 5, 10), "float32")) + loc = relax.Var("loc", R.Tensor((1, 40), "float32")) + anc = relax.Var("anc", R.Tensor((1, 10, 4), "float32")) + assert ( + relax.op.vision.multibox_transform_loc( + cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True + ).op + == Op.get("relax.vision.multibox_transform_loc") + ) + + +def test_multibox_transform_loc_infer_struct_info(): + bb = relax.BlockBuilder() + cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32")) + loc = relax.Var("loc", R.Tensor((2, 20), "float32")) + anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32")) + _check_inference( + bb, + relax.op.vision.multibox_transform_loc( + cls, loc, anc, False, 0.0, (0.1, 0.1, 0.2, 0.2), True + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 5, 4), "float32"), + relax.TensorStructInfo((2, 3, 5), "float32"), + ] + ), + ) + + +def test_multibox_transform_loc_wrong_cls_ndim(): + bb = relax.BlockBuilder() + cls = relax.Var("cls", R.Tensor((2, 3), "float32")) + loc = relax.Var("loc", R.Tensor((2, 20), "float32")) + anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc)) + + +def test_multibox_transform_loc_wrong_shape_relation(): + bb = relax.BlockBuilder() + cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32")) + loc = relax.Var("loc", R.Tensor((2, 19), "float32")) + anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc)) + + +def test_multibox_transform_loc_wrong_anchor_shape(): + bb = relax.BlockBuilder() + cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32")) + loc = relax.Var("loc", R.Tensor((2, 20), "float32")) + anc_bad_batch = relax.Var("anc_bad_batch", R.Tensor((2, 5, 4), "float32")) + anc_bad_last = relax.Var("anc_bad_last", R.Tensor((1, 5, 5), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc_bad_batch)) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc_bad_last)) + + +def test_multibox_transform_loc_wrong_dtype(): + bb = relax.BlockBuilder() + cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32")) + loc = relax.Var("loc", R.Tensor((2, 20), "float16")) + anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc)) + + +def test_multibox_transform_loc_wrong_batch(): + bb = relax.BlockBuilder() + cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32")) + loc = relax.Var("loc", R.Tensor((1, 20), "float32")) + anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc)) + + +def _multibox_ref_numpy( + cls_pred, loc_pred, anchor, variances, clip=False, threshold=0.0, keep_background=True +): + """Minimal numpy reference (avoids importing tvm.topi.testing which pulls scipy).""" + + def _softmax(x, axis): + x_max = np.max(x, axis=axis, keepdims=True) + exp = np.exp(x - x_max) + return exp / np.sum(exp, axis=axis, keepdims=True) + + B, C, N = cls_pred.shape + loc = loc_pred.reshape(B, N, 4) + scores = _softmax(cls_pred.astype("float64"), axis=1).astype(np.float32) + if threshold > 0.0: + scores = np.where(scores >= threshold, scores, 0.0).astype(np.float32) + if not keep_background: + scores = scores.copy() + scores[:, 0, :] = 0.0 + vx, vy, vw, vh = variances + boxes = np.zeros((B, N, 4), dtype=np.float32) + for b in range(B): + for a in range(N): + l, t, r, br = anchor[0, a, :] + ay = (t + br) * 0.5 + ax = (l + r) * 0.5 + ah = br - t + aw = r - l + ex, ey, ew, eh = loc[b, a, :] + ycenter = ey * vy * ah + ay + xcenter = ex * vx * aw + ax + half_h = 0.5 * np.exp(eh * vh) * ah + half_w = 0.5 * np.exp(ew * vw) * aw + ymin = ycenter - half_h + xmin = xcenter - half_w + ymax = ycenter + half_h + xmax = xcenter + half_w + if clip: + ymin = np.clip(ymin, 0.0, 1.0) + xmin = np.clip(xmin, 0.0, 1.0) + ymax = np.clip(ymax, 0.0, 1.0) + xmax = np.clip(xmax, 0.0, 1.0) + boxes[b, a, :] = (ymin, xmin, ymax, xmax) + return boxes, scores + + +def test_multibox_transform_loc_legalize_e2e(): + @tvm.script.ir_module + class Mod: + @R.function + def main( + cls: R.Tensor((1, 3, 5), "float32"), + loc: R.Tensor((1, 20), "float32"), + anc: R.Tensor((1, 5, 4), "float32"), + ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")): + return R.vision.multibox_transform_loc( + cls, + loc, + anc, + clip=False, + threshold=0.0, + variances=(1.0, 1.0, 1.0, 1.0), + keep_background=True, + ) + + cls_data = np.random.randn(1, 3, 5).astype(np.float32) + loc_data = np.random.randn(1, 20).astype(np.float32) * 0.05 + anc_data = np.array( + [ + [ + [0.1, 0.1, 0.5, 0.5], + [0.2, 0.2, 0.6, 0.6], + [0.0, 0.0, 1.0, 1.0], + [0.3, 0.3, 0.7, 0.7], + [0.05, 0.05, 0.45, 0.45], + ] + ], + dtype=np.float32, + ) + + mod = LegalizeOps()(Mod) + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (1.0, 1.0, 1.0, 1.0)) + out = vm["main"]( + tvm.runtime.tensor(cls_data, tvm.cpu()), + tvm.runtime.tensor(loc_data, tvm.cpu()), + tvm.runtime.tensor(anc_data, tvm.cpu()), + ) + tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5) + tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5) + + +def test_multibox_transform_loc_legalize_attr_branches(): + @tvm.script.ir_module + class Mod: + @R.function + def main( + cls: R.Tensor((1, 3, 4), "float32"), + loc: R.Tensor((1, 16), "float32"), + anc: R.Tensor((1, 4, 4), "float32"), + ) -> R.Tuple(R.Tensor((1, 4, 4), "float32"), R.Tensor((1, 3, 4), "float32")): + return R.vision.multibox_transform_loc( + cls, + loc, + anc, + clip=True, + threshold=0.4, + variances=(1.0, 1.0, 1.0, 1.0), + keep_background=False, + ) + + cls_data = np.array( + [[[2.0, 0.1, -0.5, 0.0], [0.2, 2.2, 0.3, -1.0], [0.1, 0.4, 2.0, 0.5]]], + dtype=np.float32, + ) + loc_data = np.array( + [[0.1, -0.2, 0.0, 0.0, -0.2, 0.1, 0.3, -0.1, 0.0, 0.0, 0.8, 0.8, 0.2, 0.2, -0.6, -0.6]], + dtype=np.float32, + ) + anc_data = np.array( + [[[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6], [0.0, 0.0, 1.0, 1.0], [0.4, 0.4, 1.2, 1.2]]], + dtype=np.float32, + ) + + mod = LegalizeOps()(Mod) + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + ref_b, ref_s = _multibox_ref_numpy( + cls_data, + loc_data, + anc_data, + (1.0, 1.0, 1.0, 1.0), + clip=True, + threshold=0.4, + keep_background=False, + ) + out = vm["main"]( + tvm.runtime.tensor(cls_data, tvm.cpu()), + tvm.runtime.tensor(loc_data, tvm.cpu()), + tvm.runtime.tensor(anc_data, tvm.cpu()), + ) + boxes = out[0].numpy() + scores = out[1].numpy() + tvm.testing.assert_allclose(boxes, ref_b, rtol=1e-4, atol=1e-5) + tvm.testing.assert_allclose(scores, ref_s, rtol=1e-4, atol=1e-5) + assert np.all(boxes >= 0.0) and np.all(boxes <= 1.0) + tvm.testing.assert_allclose(scores[:, 0, :], np.zeros_like(scores[:, 0, :])) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py index c4e8ff0c9d22..f053e3674493 100644 --- a/tests/python/relax/test_tvmscript_parser_op_vision.py +++ b/tests/python/relax/test_tvmscript_parser_op_vision.py @@ -75,6 +75,48 @@ def foo( _check(foo, bb.get()["foo"]) +def test_multibox_transform_loc(): + @R.function + def foo( + cls: R.Tensor((1, 3, 5), "float32"), + loc: R.Tensor((1, 20), "float32"), + anc: R.Tensor((1, 5, 4), "float32"), + ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")): + gv: R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")) = ( + R.vision.multibox_transform_loc( + cls, + loc, + anc, + clip=False, + threshold=0.0, + variances=(1.0, 1.0, 1.0, 1.0), + keep_background=True, + ) + ) + return gv + + cls = relax.Var("cls", R.Tensor((1, 3, 5), "float32")) + loc = relax.Var("loc", R.Tensor((1, 20), "float32")) + anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [cls, loc, anc]): + gv = bb.emit( + relax.op.vision.multibox_transform_loc( + cls, + loc, + anc, + clip=False, + threshold=0.0, + variances=(1.0, 1.0, 1.0, 1.0), + keep_background=True, + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_roi_align(): @R.function def foo( From f86ba69a0e47f8053c4d0f7933b028e87f32f409 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Fri, 27 Mar 2026 05:47:53 +0000 Subject: [PATCH 2/4] [Relax] multibox_transform_loc: simplify clip and Select --- python/tvm/topi/testing/multibox_transform_loc_python.py | 8 ++++---- python/tvm/topi/vision/multibox_transform_loc.py | 8 ++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/testing/multibox_transform_loc_python.py b/python/tvm/topi/testing/multibox_transform_loc_python.py index 8b9d81505d73..121690b68924 100644 --- a/python/tvm/topi/testing/multibox_transform_loc_python.py +++ b/python/tvm/topi/testing/multibox_transform_loc_python.py @@ -64,9 +64,9 @@ def multibox_transform_loc_python( ymax = ycenter + half_h xmax = xcenter + half_w if clip: - ymin = float(np.clip(ymin, 0.0, 1.0)) - xmin = float(np.clip(xmin, 0.0, 1.0)) - ymax = float(np.clip(ymax, 0.0, 1.0)) - xmax = float(np.clip(xmax, 0.0, 1.0)) + ymin = np.clip(ymin, 0.0, 1.0) + xmin = np.clip(xmin, 0.0, 1.0) + ymax = np.clip(ymax, 0.0, 1.0) + xmax = np.clip(xmax, 0.0, 1.0) boxes[b, a, :] = (ymin, xmin, ymax, xmax) return boxes, scores diff --git a/python/tvm/topi/vision/multibox_transform_loc.py b/python/tvm/topi/vision/multibox_transform_loc.py index 43c51392e568..ab965e798141 100644 --- a/python/tvm/topi/vision/multibox_transform_loc.py +++ b/python/tvm/topi/vision/multibox_transform_loc.py @@ -99,14 +99,10 @@ def decode_bbox(b, a, k): xmin = te.max(zero, te.min(one, xmin)) ymax = te.max(zero, te.min(one, ymax)) xmax = te.max(zero, te.min(one, xmax)) - return te.if_then_else( + return tvm.tirx.Select( k == 0, ymin, - te.if_then_else( - k == 1, - xmin, - te.if_then_else(k == 2, ymax, xmax), - ), + tvm.tirx.Select(k == 1, xmin, tvm.tirx.Select(k == 2, ymax, xmax)), ) boxes = te.compute((B, num_anchors, 4), decode_bbox, name="multibox_boxes") From d7f7adc15899d63d490ab1cb1471ee40df3f5807 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 28 Mar 2026 02:19:31 +0000 Subject: [PATCH 3/4] [Relax] multibox_transform_loc: review fixes and vision test hygiene - multibox_transform_loc.cc: wrap long lines, ASCII loc_pred doc (no Unicode arrow). - vision.h: wrap MultiboxTransformLoc clip reflection line to 100 cols. - Remove unused topi/testing/multibox_transform_loc_python.py; keep numpy ref in tests. - test_op_vision: loc_dim=24 vs 4*N infer case; @tvm.testing.requires_llvm on LLVM e2e (multibox + all_class_nms); add e2e with non-unity variances. --- include/tvm/relax/attrs/vision.h | 3 +- .../testing/multibox_transform_loc_python.py | 72 ------------------- src/relax/op/vision/multibox_transform_loc.cc | 9 ++- tests/python/relax/test_op_vision.py | 61 +++++++++++++++- 4 files changed, 66 insertions(+), 79 deletions(-) delete mode 100644 python/tvm/topi/testing/multibox_transform_loc_python.py diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index c73ac3b6b556..5189a1c03ec1 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -83,7 +83,8 @@ struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter() - .def_ro("clip", &MultiboxTransformLocAttrs::clip, "Clip decoded ymin,xmin,ymax,xmax to [0,1].") + .def_ro("clip", &MultiboxTransformLocAttrs::clip, + "Clip decoded ymin,xmin,ymax,xmax to [0,1].") .def_ro("threshold", &MultiboxTransformLocAttrs::threshold, "After softmax, zero scores strictly below this value.") .def_ro("variances", &MultiboxTransformLocAttrs::variances, diff --git a/python/tvm/topi/testing/multibox_transform_loc_python.py b/python/tvm/topi/testing/multibox_transform_loc_python.py deleted file mode 100644 index 121690b68924..000000000000 --- a/python/tvm/topi/testing/multibox_transform_loc_python.py +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Numpy reference for multibox_transform_loc.""" - -import numpy as np - - -def _softmax(x, axis): - x_max = np.max(x, axis=axis, keepdims=True) - exp = np.exp(x - x_max) - return exp / np.sum(exp, axis=axis, keepdims=True) - - -def multibox_transform_loc_python( - cls_pred, - loc_pred, - anchor, - variances, - clip=False, - threshold=0.0, - keep_background=True, -): - """Reference implementation aligned with ``topi.vision.multibox_transform_loc``.""" - B, C, N = cls_pred.shape - loc = loc_pred.reshape(B, N, 4) - scores = _softmax(cls_pred.astype("float64"), axis=1).astype(np.float32) - if threshold > 0.0: - scores = np.where(scores >= threshold, scores, 0.0).astype(np.float32) - if not keep_background: - scores = scores.copy() - scores[:, 0, :] = 0.0 - - vx, vy, vw, vh = variances - boxes = np.zeros((B, N, 4), dtype=np.float32) - for b in range(B): - for a in range(N): - l, t, r, br = anchor[0, a, :] - ay = (t + br) * 0.5 - ax = (l + r) * 0.5 - ah = br - t - aw = r - l - ex, ey, ew, eh = loc[b, a, :] - ycenter = ey * vy * ah + ay - xcenter = ex * vx * aw + ax - half_h = 0.5 * np.exp(eh * vh) * ah - half_w = 0.5 * np.exp(ew * vw) * aw - ymin = ycenter - half_h - xmin = xcenter - half_w - ymax = ycenter + half_h - xmax = xcenter + half_w - if clip: - ymin = np.clip(ymin, 0.0, 1.0) - xmin = np.clip(xmin, 0.0, 1.0) - ymax = np.clip(ymax, 0.0, 1.0) - xmax = np.clip(xmax, 0.0, 1.0) - boxes[b, a, :] = (ymin, xmin, ymax, xmax) - return boxes, scores diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index 0fd00a27b380..75f0a05b85c1 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -57,7 +57,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { ctx->ReportFatal(Diagnostic::Error(call) - << "multibox_transform_loc: expected 3 inputs (cls_pred, loc_pred, anchor), got " + << "multibox_transform_loc: expected 3 inputs (cls_pred, loc_pred, anchor), " + "got " << call->args.size()); } @@ -83,7 +84,8 @@ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil << anchor_sinfo->ndim); } - if (!cls_sinfo->IsUnknownDtype() && !loc_sinfo->IsUnknownDtype() && cls_sinfo->dtype != loc_sinfo->dtype) { + if (!cls_sinfo->IsUnknownDtype() && !loc_sinfo->IsUnknownDtype() && + cls_sinfo->dtype != loc_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) << "multibox_transform_loc: cls_pred and loc_pred dtype must match, got " << cls_sinfo->dtype << " vs " << loc_sinfo->dtype); @@ -180,7 +182,8 @@ TVM_REGISTER_OP("relax.vision.multibox_transform_loc") .set_attrs_type() .set_num_inputs(3) .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).") - .add_argument("loc_pred", "Tensor", "[B,4*N] box encodings per anchor as (x,y,w,h) after yxhw→xywh.") + .add_argument("loc_pred", "Tensor", + "[B,4*N] box encodings (x,y,w,h); TFLite yxhw order remapped to xywh.") .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb (left,top,right,bottom).") .set_attr("FInferStructInfo", InferStructInfoMultiboxTransformLoc) .set_attr("FPurity", Bool(true)); diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index d39977b8c2e3..cded9f5f29e5 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -286,6 +286,7 @@ def main( ) +@tvm.testing.requires_llvm def test_all_class_non_max_suppression_legalize_e2e(): @tvm.script.ir_module class NMSModule: @@ -387,10 +388,14 @@ def test_multibox_transform_loc_wrong_cls_ndim(): def test_multibox_transform_loc_wrong_shape_relation(): bb = relax.BlockBuilder() cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32")) - loc = relax.Var("loc", R.Tensor((2, 19), "float32")) anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32")) + loc_bad_div = relax.Var("loc_bad_div", R.Tensor((2, 19), "float32")) with pytest.raises(TVMError): - bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc)) + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_div, anc)) + # Divisible by 4 but loc_dim != 4*N (N=5 -> expect 20, not 24) + loc_bad_n = relax.Var("loc_bad_n", R.Tensor((2, 24), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_n, anc)) def test_multibox_transform_loc_wrong_anchor_shape(): @@ -426,7 +431,7 @@ def test_multibox_transform_loc_wrong_batch(): def _multibox_ref_numpy( cls_pred, loc_pred, anchor, variances, clip=False, threshold=0.0, keep_background=True ): - """Minimal numpy reference (avoids importing tvm.topi.testing which pulls scipy).""" + """Numpy reference aligned with ``topi.vision.multibox_transform_loc``.""" def _softmax(x, axis): x_max = np.max(x, axis=axis, keepdims=True) @@ -468,6 +473,7 @@ def _softmax(x, axis): return boxes, scores +@tvm.testing.requires_llvm def test_multibox_transform_loc_legalize_e2e(): @tvm.script.ir_module class Mod: @@ -515,6 +521,55 @@ def main( tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5) +@tvm.testing.requires_llvm +def test_multibox_transform_loc_legalize_e2e_nonunity_variances(): + @tvm.script.ir_module + class Mod: + @R.function + def main( + cls: R.Tensor((1, 3, 5), "float32"), + loc: R.Tensor((1, 20), "float32"), + anc: R.Tensor((1, 5, 4), "float32"), + ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")): + return R.vision.multibox_transform_loc( + cls, + loc, + anc, + clip=False, + threshold=0.0, + variances=(0.1, 0.1, 0.2, 0.2), + keep_background=True, + ) + + cls_data = np.random.randn(1, 3, 5).astype(np.float32) + loc_data = np.random.randn(1, 20).astype(np.float32) * 0.05 + anc_data = np.array( + [ + [ + [0.1, 0.1, 0.5, 0.5], + [0.2, 0.2, 0.6, 0.6], + [0.0, 0.0, 1.0, 1.0], + [0.3, 0.3, 0.7, 0.7], + [0.05, 0.05, 0.45, 0.45], + ] + ], + dtype=np.float32, + ) + + mod = LegalizeOps()(Mod) + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (0.1, 0.1, 0.2, 0.2)) + out = vm["main"]( + tvm.runtime.tensor(cls_data, tvm.cpu()), + tvm.runtime.tensor(loc_data, tvm.cpu()), + tvm.runtime.tensor(anc_data, tvm.cpu()), + ) + tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5) + tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5) + + +@tvm.testing.requires_llvm def test_multibox_transform_loc_legalize_attr_branches(): @tvm.script.ir_module class Mod: From 1348589e99a74ba4ed1db7f9cd3b0475935022bb Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <792179245@qq.com> Date: Sat, 28 Mar 2026 02:43:37 +0000 Subject: [PATCH 4/4] [Relax] Document multibox_transform_loc infer limits and variances overflow --- include/tvm/relax/attrs/vision.h | 3 ++- python/tvm/relax/op/vision/multibox_transform_loc.py | 8 ++++++++ src/relax/op/vision/multibox_transform_loc.cc | 12 ++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 5189a1c03ec1..4e3351bb90c8 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -88,7 +88,8 @@ struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter= threshold)``. variances : tuple of 4 floats ``(x,y,w,h)`` = TFLite ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale``. + Use magnitudes consistent with the model: very large ``w``/``h`` entries scale the + encoded height/width terms inside ``exp(...)`` and can overflow in float32/float16. keep_background : bool If False, set output scores at class index 0 to zero. @@ -65,6 +67,12 @@ def multibox_transform_loc( - ``N = cls_pred.shape[2]``; ``loc_pred.shape[1] == 4*N``; ``anchor.shape == [1,N,4]``. - ``loc_pred.shape[1]`` must be divisible by 4. - ``cls_pred.shape[0]`` must equal ``loc_pred.shape[0]`` (batch). + + If ``cls_pred`` has **unknown** shape, inference only returns generic rank-3 tensor + struct info for the two outputs; it does **not** verify ``4*N`` vs ``loc_pred`` or + ``anchor.shape[1]`` vs ``N``, because ``N`` is not available statically. Other checks + (ranks, dtypes, ``loc_pred.shape[1] % 4 == 0`` when known, batch match when both batch + axes are known, etc.) still run where applicable. """ return _ffi_api.multibox_transform_loc( cls_pred, diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index 75f0a05b85c1..e01e569b78f0 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -54,6 +54,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.vision.multibox_transform_loc", multibox_transform_loc); } +/*! + * \brief Infer struct info for relax.vision.multibox_transform_loc. + * + * \note Shape cross-checks that need the anchor count N (e.g. loc_pred.shape[1] == 4*N, + * anchor.shape[1] == N with N = cls_pred.shape[2]) run only when cls_pred has a known + * static shape. If cls_pred shape is unknown, inference returns generic rank-3 outputs and + * skips those N-based relations; other checks (ndim, dtype, loc dim divisible by 4, etc.) + * still apply when their inputs are known. + */ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { ctx->ReportFatal(Diagnostic::Error(call) @@ -179,6 +188,9 @@ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil } TVM_REGISTER_OP("relax.vision.multibox_transform_loc") + .describe("Decode SSD/TFLite-style priors and offsets into boxes and softmax scores. If " + "cls_pred shape is unknown, N-based loc/anchor shape checks are skipped in " + "inference. Very large variances (w,h) can overflow exp in half box sizes.") .set_attrs_type() .set_num_inputs(3) .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).")