diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 59a1dd7314fc..4e3351bb90c8 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,30 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode);
 };  // struct ROIAlignAttrs
 
+/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box decode). */
+struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter {
+  bool clip;             // Clip decoded ymin/xmin/ymax/xmax to [0, 1] when true.
+  double threshold;      // Softmax scores strictly below this value are zeroed.
+  ffi::Array variances;  // (x, y, w, h) scales applied to the box encodings.
+  bool keep_background;  // When false, scores of class 0 (background) are forced to 0.
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef()
+        .def_ro("clip", &MultiboxTransformLocAttrs::clip,
+                "Clip decoded ymin,xmin,ymax,xmax to [0,1].")
+        .def_ro("threshold", &MultiboxTransformLocAttrs::threshold,
+                "After softmax, zero scores strictly below this value.")
+        .def_ro("variances", &MultiboxTransformLocAttrs::variances,
+                "(x,y,w,h) scales = TFLite 1/x_scale,1/y_scale,1/w_scale,1/h_scale on "
+                "encodings. Very large w/h scales can overflow exp in decode.")
+        .def_ro("keep_background", &MultiboxTransformLocAttrs::keep_background,
+                "If false, force output scores[:,0,:] to 0 (background class).");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultiboxTransformLocAttrs",
+                                    MultiboxTransformLocAttrs, BaseAttrsNode);
+};  // struct MultiboxTransformLocAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 0abd700562e2..60cbc3178aee 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3205,9 +3205,10 @@ def convert_dequantize(self, op):
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
         raise NotImplementedError(
-            "DETECTION_POSTPROCESS requires vision ops (multibox_transform_loc, "
-            "non_max_suppression, get_valid_counts) not yet available in Relax. "
-            "See https://github.com/apache/tvm/issues/XXXX"
+            "DETECTION_POSTPROCESS is not wired in this frontend yet: it still needs "
+            "Relax NMS / get_valid_counts / related vision helpers (see dead code below). "
+            "relax.vision.multibox_transform_loc exists; tracking: "
+            "https://github.com/apache/tvm/issues/18928"
         )
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
         custom_options = FlexBufferDecoder(flexbuffer).decode()
@@ -3340,9 +3341,8 @@ def convert_nms_v5(self, op):
         """Convert TFLite NonMaxSuppressionV5"""
         # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5
         raise NotImplementedError(
-            "NON_MAX_SUPPRESSION_V5 requires vision ops (get_valid_counts, "
-            "non_max_suppression) not yet available in Relax. "
-            "See https://github.com/apache/tvm/issues/XXXX"
+            "NON_MAX_SUPPRESSION_V5 is not wired in this frontend yet (needs get_valid_counts, "
+            "non_max_suppression, etc.). 
Tracking: https://github.com/apache/tvm/issues/18928"
         )
 
         input_tensors = self.get_input_tensors(op)
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0bc3f6578432..ee1a2c24206e 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -157,7 +157,7 @@
     tanh,
     trunc,
 )
-from .vision import all_class_non_max_suppression, roi_align
+from .vision import all_class_non_max_suppression, multibox_transform_loc, roi_align
 
 
 def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index a3b6544dcc6e..e8c91f04b459 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -251,6 +251,11 @@ class ROIAlignAttrs(Attrs):
     """Attributes for vision.roi_align"""
 
 
+@tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs")
+class MultiboxTransformLocAttrs(Attrs):
+    """Attributes for vision.multibox_transform_loc (registered Python-side attrs class)."""
+
+
 @tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
 class Conv1DAttrs(Attrs):
     """Attributes for nn.conv1d"""
diff --git a/python/tvm/relax/op/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py
index 76d9ea35a11c..58266c5b2add 100644
--- a/python/tvm/relax/op/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -17,5 +17,6 @@
 # under the License.
 """VISION operators."""
 
+from .multibox_transform_loc import *
 from .nms import *
 from .roi_align import *
diff --git a/python/tvm/relax/op/vision/multibox_transform_loc.py b/python/tvm/relax/op/vision/multibox_transform_loc.py
new file mode 100644
index 000000000000..6830b1dc6321
--- /dev/null
+++ b/python/tvm/relax/op/vision/multibox_transform_loc.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Multibox location transform for object detection."""
+
+from . import _ffi_api
+
+
+def multibox_transform_loc(
+    cls_pred,
+    loc_pred,
+    anchor,
+    clip=False,
+    threshold=0.0,
+    variances=(1.0, 1.0, 1.0, 1.0),
+    keep_background=True,
+):
+    """SSD / TFLite-style decode: priors + offsets → boxes; logits → softmax scores.
+
+    Box decode follows TFLite ``DecodeCenterSizeBoxes``; expected tensor layout matches
+    ``tflite_frontend.convert_detection_postprocess`` (loc reorder yxhw→xywh, anchor ltrb).
+
+    Parameters
+    ----------
+    cls_pred : relax.Expr
+        ``[B, C, N]`` class logits (pre-softmax).
+    loc_pred : relax.Expr
+        ``[B, 4*N]`` per-anchor encodings as ``(x,y,w,h)`` after reorder (see above).
+    anchor : relax.Expr
+        ``[1, N, 4]`` priors: ``(left, top, right, bottom)``.
+    clip : bool
+        If True, clip ``ymin,xmin,ymax,xmax`` to ``[0, 1]``.
+    threshold : float
+        After softmax, multiply scores by mask ``(score >= threshold)``.
+    variances : tuple of 4 floats
+        ``(x,y,w,h)`` = TFLite ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale``.
+        Use magnitudes consistent with the model: very large ``w``/``h`` entries scale the
+        encoded height/width terms inside ``exp(...)`` and can overflow in float32/float16.
+    keep_background : bool
+        If False, set output scores at class index 0 to zero.
+
+    Returns
+    -------
+    result : relax.Expr
+        Tuple ``(boxes, scores)``: ``boxes`` is ``[B, N, 4]`` as ``(ymin,xmin,ymax,xmax)``;
+        ``scores`` is ``[B, C, N]``: class-axis softmax, then threshold mask / class-0 zeroing.
+
+    Notes
+    -----
+    **Shape/dtype (checked in ``FInferStructInfo`` when static):**
+
+    - ``cls_pred``: 3-D; ``loc_pred``: 2-D; ``anchor``: 3-D.
+    - ``cls_pred``, ``loc_pred``, ``anchor`` dtypes must match.
+    - ``N = cls_pred.shape[2]``; ``loc_pred.shape[1] == 4*N``; ``anchor.shape == [1,N,4]``.
+    - ``loc_pred.shape[1]`` must be divisible by 4.
+    - ``cls_pred.shape[0]`` must equal ``loc_pred.shape[0]`` (batch).
+
+    If ``cls_pred`` has **unknown** shape, inference only returns generic rank-3 tensor
+    struct info for the two outputs; it does **not** verify ``4*N`` vs ``loc_pred`` or
+    ``anchor.shape[1]`` vs ``N``, because ``N`` is not available statically. Other checks
+    (ranks, dtypes, ``loc_pred.shape[1] % 4 == 0`` when known, batch match when both batch
+    axes are known, etc.) still run where applicable.
+    """
+    return _ffi_api.multibox_transform_loc(
+        cls_pred,
+        loc_pred,
+        anchor,
+        clip,
+        threshold,
+        variances,
+        keep_background,
+    )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py
index 7a1e305f39f0..28367a67a361 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -118,3 +118,27 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
         aligned=call.attrs.aligned,
         layout=call.attrs.layout,
     )
+
+
+@register_legalize("relax.vision.multibox_transform_loc")
+def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
+    variances = tuple(float(x) for x in call.attrs.variances)  # attrs entries -> plain floats
+
+    def _te(cls_pred, loc_pred, anchor):  # TE builder: forwards the call's attrs to TOPI
+        return topi.vision.multibox_transform_loc(
+            cls_pred,
+            loc_pred,
+            anchor,
+            variances,
+            clip=call.attrs.clip,
+            threshold=call.attrs.threshold,
+            keep_background=call.attrs.keep_background,
+        )
+
+    return bb.call_te(  # the three call arguments become the inputs of _te
+        _te,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        primfunc_name_hint="multibox_transform_loc",
+    )
diff --git a/python/tvm/topi/vision/__init__.py b/python/tvm/topi/vision/__init__.py
index 75725a8a4bea..cb0467c98cd4 100644
--- a/python/tvm/topi/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -17,5 +17,6 @@
 # under the License.
 """Vision operators."""
 
+from .multibox_transform_loc import *
 from .nms import *
 from .roi_align import *
diff --git a/python/tvm/topi/vision/multibox_transform_loc.py b/python/tvm/topi/vision/multibox_transform_loc.py
new file mode 100644
index 000000000000..ab965e798141
--- /dev/null
+++ b/python/tvm/topi/vision/multibox_transform_loc.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Multibox location transform (SSD / TFLite DetectionPostProcess decode)."""
+
+import tvm
+from tvm import te, topi
+
+
+def multibox_transform_loc(
+    cls_pred,
+    loc_pred,
+    anchor,
+    variances,
+    clip=False,
+    threshold=0.0,
+    keep_background=True,
+):
+    """TFLite ``DecodeCenterSizeBoxes``-style decode + softmax score post-process.
+
+    Inputs must match Relax op contracts: ``cls_pred [B,C,N]``, ``loc_pred [B,4*N]``,
+    ``anchor [1,N,4]`` ltrb; per-anchor loc order ``(x,y,w,h)`` after yxhw→xywh reorder.
+
+    Parameters
+    ----------
+    cls_pred : te.Tensor
+        ``[B, C, N]`` logits.
+    loc_pred : te.Tensor
+        ``[B, 4*N]`` encodings ``(x,y,w,h)`` per anchor.
+    anchor : te.Tensor
+        ``[1, N, 4]`` ``(left, top, right, bottom)``.
+    variances : tuple of 4 float
+        ``(x,y,w,h)`` = ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale`` (TFLite).
+    clip : bool
+        Clip ``ymin,xmin,ymax,xmax`` to ``[0,1]``.
+    threshold : float
+        After softmax: ``scores *= (scores >= threshold)``.
+    keep_background : bool
+        If False: ``scores[:,0,:] = 0``.
+
+    Returns
+    -------
+    boxes : te.Tensor
+        ``[B, N, 4]`` as ``(ymin,xmin,ymax,xmax)``.
+    scores : te.Tensor
+        ``[B, C, N]`` softmax, then threshold mask and optional background zero.
+    """
+    dtype = cls_pred.dtype  # every constant below is created in the logits' dtype
+    B = cls_pred.shape[0]
+    num_anchors = cls_pred.shape[2]
+    loc_reshaped = topi.reshape(loc_pred, [B, num_anchors, 4])  # [B, 4*N] -> [B, N, 4]
+
+    vx = tvm.tirx.const(float(variances[0]), dtype)
+    vy = tvm.tirx.const(float(variances[1]), dtype)
+    vw = tvm.tirx.const(float(variances[2]), dtype)
+    vh = tvm.tirx.const(float(variances[3]), dtype)
+    half = tvm.tirx.const(0.5, dtype)
+    zero = tvm.tirx.const(0.0, dtype)
+    one = tvm.tirx.const(1.0, dtype)
+    th = tvm.tirx.const(float(threshold), dtype)
+
+    def decode_bbox(b, a, k):  # k in 0..3 picks ymin, xmin, ymax, xmax
+        l = anchor[0, a, 0]
+        t = anchor[0, a, 1]
+        r = anchor[0, a, 2]
+        br = anchor[0, a, 3]
+        ay = (t + br) * half  # anchor center (y, x)
+        ax = (l + r) * half
+        ah = br - t  # anchor size (h, w)
+        aw = r - l
+        ex = loc_reshaped[b, a, 0]  # encodings are read in (x, y, w, h) order
+        ey = loc_reshaped[b, a, 1]
+        ew = loc_reshaped[b, a, 2]
+        eh = loc_reshaped[b, a, 3]
+        ycenter = ey * vy * ah + ay  # center = scaled offset * anchor size + anchor center
+        xcenter = ex * vx * aw + ax
+        half_h = half * te.exp(eh * vh) * ah  # half extent = 0.5 * exp(scaled size) * anchor size
+        half_w = half * te.exp(ew * vw) * aw
+        ymin = ycenter - half_h
+        xmin = xcenter - half_w
+        ymax = ycenter + half_h
+        xmax = xcenter + half_w
+        if clip:  # Python-level flag: the clamp is only built into the expression when set
+            ymin = te.max(zero, te.min(one, ymin))
+            xmin = te.max(zero, te.min(one, xmin))
+            ymax = te.max(zero, te.min(one, ymax))
+            xmax = te.max(zero, te.min(one, xmax))
+        return tvm.tirx.Select(
+            k == 0,
+            ymin,
+            tvm.tirx.Select(k == 1, xmin, tvm.tirx.Select(k == 2, ymax, xmax)),
+        )
+
+    boxes = te.compute((B, num_anchors, 4), decode_bbox, name="multibox_boxes")
+
+    scores = topi.nn.softmax(cls_pred, axis=1)  # softmax over the class axis
+    mask = topi.cast(topi.greater_equal(scores, th), dtype)  # 1 where score >= threshold else 0
+    scores = scores * mask
+    if not keep_background:
+
+        def zero_bg(b, c, n):  # same scores with class index 0 replaced by zero
+            s = scores[b, c, n]
+            return te.if_then_else(c == 0, zero, s)
+
+        scores = te.compute(scores.shape, zero_bg, name="multibox_scores_bg")
+
+    return [boxes, scores]
diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc
new file mode 100644
index 000000000000..e01e569b78f0
--- /dev/null
+++ 
b/src/relax/op/vision/multibox_transform_loc.cc
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file multibox_transform_loc.cc
+ * \brief Multibox transform (location decode) for object detection.
+ */
+
+#include "multibox_transform_loc.h"
+
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { MultiboxTransformLocAttrs::RegisterReflection(); }
+
+Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool clip, double threshold,
+                            ffi::Array variances, bool keep_background) {
+  TVM_FFI_ICHECK_EQ(variances.size(), 4)
+      << "multibox_transform_loc: variances must be length 4 (x,y,w,h), got " << variances.size();
+
+  auto attrs = ffi::make_object();  // attrs node that carries the four op attributes
+  attrs->clip = clip;
+  attrs->threshold = threshold;
+  attrs->variances = std::move(variances);
+  attrs->keep_background = keep_background;
+
+  static const Op& op = Op::Get("relax.vision.multibox_transform_loc");
+  return Call(op, {std::move(cls_pred), std::move(loc_pred), std::move(anchor)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.vision.multibox_transform_loc", multibox_transform_loc);
+}
+
+/*!
+ * \brief Infer struct info for relax.vision.multibox_transform_loc.
+ *
+ * \note Shape cross-checks that need the anchor count N (e.g. loc_pred.shape[1] == 4*N,
+ * anchor.shape[1] == N with N = cls_pred.shape[2]) run only when cls_pred has a known
+ * static shape. If cls_pred shape is unknown, inference returns generic rank-3 outputs and
+ * skips those N-based relations; other checks (ndim, dtype, loc dim divisible by 4, etc.)
+ * still apply when their inputs are known.
+ */
+StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: expected 3 inputs (cls_pred, loc_pred, anchor), "
+                        "got "
+                     << call->args.size());
+  }
+
+  ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx);
+  const auto cls_sinfo = input_sinfo[0];
+  const auto loc_sinfo = input_sinfo[1];
+  const auto anchor_sinfo = input_sinfo[2];
+
+  if (!cls_sinfo->IsUnknownNdim() && cls_sinfo->ndim != 3) {  // rank checks run only when known
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: cls_pred must be 3-D [B, num_classes, N], got "
+                        "ndim "
+                     << cls_sinfo->ndim);
+  }
+  if (!loc_sinfo->IsUnknownNdim() && loc_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: loc_pred must be 2-D [B, 4*N], got ndim "
+                     << loc_sinfo->ndim);
+  }
+  if (!anchor_sinfo->IsUnknownNdim() && anchor_sinfo->ndim != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: anchor must be 3-D [1, N, 4] ltrb, got ndim "
+                     << anchor_sinfo->ndim);
+  }
+
+  if (!cls_sinfo->IsUnknownDtype() && !loc_sinfo->IsUnknownDtype() &&
+      cls_sinfo->dtype != loc_sinfo->dtype) {  // dtypes are compared against cls_pred's dtype
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: cls_pred and loc_pred dtype must match, got "
+                     << cls_sinfo->dtype << " vs " << loc_sinfo->dtype);
+  }
+  if (!cls_sinfo->IsUnknownDtype() && !anchor_sinfo->IsUnknownDtype() &&
+      cls_sinfo->dtype != anchor_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: cls_pred and anchor dtype must match, got "
+                     << cls_sinfo->dtype << " vs " << anchor_sinfo->dtype);
+  }
+
+  auto vdev = cls_sinfo->vdevice;  // both outputs reuse cls_pred's vdevice
+  const auto* cls_shape = cls_sinfo->shape.as();
+  const auto* loc_shape = loc_sinfo->shape.as();
+  const auto* anchor_shape = anchor_sinfo->shape.as();
+
+  if (loc_shape != nullptr) {  // loc_pred.shape[1] must be a multiple of 4 when it is constant
+    const auto* loc_dim1 = loc_shape->values[1].as();
+    if (loc_dim1 != nullptr && loc_dim1->value % 4 != 0) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: loc_pred.shape[1] must be divisible by 4, got "
+                       << loc_dim1->value);
+    }
+  }
+
+  if (cls_shape != nullptr && loc_shape != nullptr) {  // batch match, only for constant dims
+    const auto* cls_b = cls_shape->values[0].as();
+    const auto* loc_b = loc_shape->values[0].as();
+    if (cls_b != nullptr && loc_b != nullptr && cls_b->value != loc_b->value) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: cls_pred.shape[0] must match loc_pred.shape[0], "
+                          "got B="
+                       << cls_b->value << " vs " << loc_b->value);
+    }
+  }
+
+  if (anchor_shape != nullptr) {  // anchor layout: leading dim 1, trailing dim 4
+    const auto* anchor_batch = anchor_shape->values[0].as();
+    if (anchor_batch != nullptr && anchor_batch->value != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: anchor.shape[0] must be 1, got "
+                       << anchor_batch->value);
+    }
+    const auto* anchor_last = anchor_shape->values[2].as();
+    if (anchor_last != nullptr && anchor_last->value != 4) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: anchor.shape[2] must be 4 (ltrb), got "
+                       << anchor_last->value);
+    }
+  }
+
+  if (cls_shape == nullptr) {  // no static cls_pred shape: outputs are rank-3 with cls dtype only
+    ffi::Array fields = {TensorStructInfo(cls_sinfo->dtype, 3, vdev),
+                         TensorStructInfo(cls_sinfo->dtype, 3, vdev)};
+    return TupleStructInfo(fields);
+  }
+
+  const auto& batch = cls_shape->values[0];
+  const auto& num_classes = cls_shape->values[1];
+  const auto& num_anchors = cls_shape->values[2];
+
+  if (loc_shape != nullptr) {  // loc_pred.shape[1] == 4*N, checked when both are constants
+    const auto* num_anchors_imm = num_anchors.as();
+    const auto* loc_dim1 = loc_shape->values[1].as();
+    if (num_anchors_imm != nullptr && loc_dim1 != nullptr &&
+        loc_dim1->value != num_anchors_imm->value * 4) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: loc_pred.shape[1] must equal 4*N with "
+                          "N=cls_pred.shape[2], got loc_dim="
+                       << loc_dim1->value << ", N=" << num_anchors_imm->value);
+    }
+  }
+  if (anchor_shape != nullptr) {  // anchor.shape[1] == N, checked when both are constants
+    const auto* num_anchors_imm = num_anchors.as();
+    const auto* anchor_num_anchors = anchor_shape->values[1].as();
+    if (num_anchors_imm != nullptr && anchor_num_anchors != nullptr &&
+        anchor_num_anchors->value != num_anchors_imm->value) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: anchor.shape[1] must equal N=cls_pred.shape[2], "
+                          "got anchor_N="
+                       << anchor_num_anchors->value << ", N=" << num_anchors_imm->value);
+    }
+  }
+
+  ffi::Array boxes_shape = {batch, num_anchors, Integer(4)};  // boxes: [B, N, 4]
+  ffi::Array scores_shape = {batch, num_classes, num_anchors};  // scores: same dims as cls_pred
+  ffi::Array fields = {
+      TensorStructInfo(ShapeExpr(boxes_shape), cls_sinfo->dtype, vdev),
+      TensorStructInfo(ShapeExpr(scores_shape), cls_sinfo->dtype, vdev)};
+  return TupleStructInfo(fields);
+}
+
+TVM_REGISTER_OP("relax.vision.multibox_transform_loc")
+    .describe("Decode SSD/TFLite-style priors and offsets into boxes and softmax scores. If "
+              "cls_pred shape is unknown, N-based loc/anchor shape checks are skipped in "
+              "inference. Very large variances (w,h) can overflow exp in half box sizes.")
+    .set_attrs_type()
+    .set_num_inputs(3)
+    .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).")
+    .add_argument("loc_pred", "Tensor",
+                  "[B,4*N] box encodings (x,y,w,h); TFLite yxhw order remapped to xywh.")
+    .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb (left,top,right,bottom).")
+    .set_attr("FInferStructInfo", InferStructInfoMultiboxTransformLoc)
+    .set_attr("FPurity", Bool(true));
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/vision/multibox_transform_loc.h b/src/relax/op/vision/multibox_transform_loc.h
new file mode 100644
index 000000000000..726bc4c0e582
--- /dev/null
+++ b/src/relax/op/vision/multibox_transform_loc.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file multibox_transform_loc.h
+ * \brief The functions to make Relax multibox_transform_loc operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
+#define TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
+
+#include
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Decode SSD box encodings and prepare class scores (TFLite-compatible). 
*/
+Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool clip, double threshold,
+                            ffi::Array variances, bool keep_background);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py
index b902518b49bb..cded9f5f29e5 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -286,6 +286,7 @@ def main(
     )
 
 
+@tvm.testing.requires_llvm
 def test_all_class_non_max_suppression_legalize_e2e():
     @tvm.script.ir_module
     class NMSModule:
@@ -344,5 +345,287 @@ def main(
     tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections, 3))
 
 
+def test_multibox_transform_loc_op_correctness():  # the builder emits a call to the registered op
+    cls = relax.Var("cls", R.Tensor((1, 5, 10), "float32"))
+    loc = relax.Var("loc", R.Tensor((1, 40), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 10, 4), "float32"))
+    assert (
+        relax.op.vision.multibox_transform_loc(
+            cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True
+        ).op
+        == Op.get("relax.vision.multibox_transform_loc")
+    )
+
+
+def test_multibox_transform_loc_infer_struct_info():  # static shapes: B=2, C=3, N=5
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    _check_inference(
+        bb,
+        relax.op.vision.multibox_transform_loc(
+            cls, loc, anc, False, 0.0, (0.1, 0.1, 0.2, 0.2), True
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 5, 4), "float32"),  # boxes [B, N, 4]
+                relax.TensorStructInfo((2, 3, 5), "float32"),  # scores [B, C, N]
+            ]
+        ),
+    )
+
+
+def test_multibox_transform_loc_wrong_cls_ndim():  # a 2-D cls_pred must be rejected
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc)) 
+
+
+def test_multibox_transform_loc_wrong_shape_relation():  # loc_pred.shape[1] must equal 4*N
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    loc_bad_div = relax.Var("loc_bad_div", R.Tensor((2, 19), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_div, anc))
+    # Divisible by 4 but loc_dim != 4*N (N=5 -> expect 20, not 24)
+    loc_bad_n = relax.Var("loc_bad_n", R.Tensor((2, 24), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_n, anc))
+
+
+def test_multibox_transform_loc_wrong_anchor_shape():  # anchor must be [1, N, 4]
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+    anc_bad_batch = relax.Var("anc_bad_batch", R.Tensor((2, 5, 4), "float32"))
+    anc_bad_last = relax.Var("anc_bad_last", R.Tensor((1, 5, 5), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc_bad_batch))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc_bad_last))
+
+
+def test_multibox_transform_loc_wrong_dtype():  # loc_pred is float16, cls_pred is float32
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float16"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def test_multibox_transform_loc_wrong_batch():  # cls_pred batch 2 vs loc_pred batch 1
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((1, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def _multibox_ref_numpy(
+    cls_pred, loc_pred, anchor, variances, clip=False, threshold=0.0, keep_background=True
+):
+    """Numpy reference aligned with ``topi.vision.multibox_transform_loc``."""
+
+    def _softmax(x, axis):  # max-subtracted softmax along ``axis``
+        x_max = np.max(x, axis=axis, keepdims=True)
+        exp = np.exp(x - x_max)
+        return exp / np.sum(exp, axis=axis, keepdims=True)
+
+    B, C, N = cls_pred.shape
+    loc = loc_pred.reshape(B, N, 4)
+    scores = _softmax(cls_pred.astype("float64"), axis=1).astype(np.float32)
+    if threshold > 0.0:  # softmax output is non-negative, so thresholds <= 0 mask nothing
+        scores = np.where(scores >= threshold, scores, 0.0).astype(np.float32)
+    if not keep_background:
+        scores = scores.copy()
+        scores[:, 0, :] = 0.0
+    vx, vy, vw, vh = variances
+    boxes = np.zeros((B, N, 4), dtype=np.float32)
+    for b in range(B):
+        for a in range(N):
+            l, t, r, br = anchor[0, a, :]
+            ay = (t + br) * 0.5
+            ax = (l + r) * 0.5
+            ah = br - t
+            aw = r - l
+            ex, ey, ew, eh = loc[b, a, :]
+            ycenter = ey * vy * ah + ay
+            xcenter = ex * vx * aw + ax
+            half_h = 0.5 * np.exp(eh * vh) * ah
+            half_w = 0.5 * np.exp(ew * vw) * aw
+            ymin = ycenter - half_h
+            xmin = xcenter - half_w
+            ymax = ycenter + half_h
+            xmax = xcenter + half_w
+            if clip:
+                ymin = np.clip(ymin, 0.0, 1.0)
+                xmin = np.clip(xmin, 0.0, 1.0)
+                ymax = np.clip(ymax, 0.0, 1.0)
+                xmax = np.clip(xmax, 0.0, 1.0)
+            boxes[b, a, :] = (ymin, xmin, ymax, xmax)
+    return boxes, scores
+
+
+@tvm.testing.requires_llvm
+def test_multibox_transform_loc_legalize_e2e():  # unit variances, no clip/threshold
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(
+            cls: R.Tensor((1, 3, 5), "float32"),
+            loc: R.Tensor((1, 20), "float32"),
+            anc: R.Tensor((1, 5, 4), "float32"),
+        ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")):
+            return R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=True,
+            )
+
+    # Seeded generators keep the inputs identical on every run, so a failure can be reproduced.
+    cls_data = np.random.default_rng(0).standard_normal((1, 3, 5)).astype(np.float32)
+    loc_data = np.random.default_rng(1).standard_normal((1, 20)).astype(np.float32) * 0.05
+    anc_data = np.array(
+        [
+            [
+                [0.1, 0.1, 0.5, 0.5],
+                [0.2, 0.2, 0.6, 0.6],
+                [0.0, 0.0, 1.0, 1.0],
+                [0.3, 0.3, 0.7, 0.7],
+                [0.05, 0.05, 0.45, 0.45],
+            ]
+        ],
+        dtype=np.float32,
+    )
+
+    mod = LegalizeOps()(Mod)
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (1.0, 1.0, 1.0, 1.0))
+    out = vm["main"](
+        tvm.runtime.tensor(cls_data, tvm.cpu()),
+        tvm.runtime.tensor(loc_data, tvm.cpu()),
+        tvm.runtime.tensor(anc_data, tvm.cpu()),
+    )
+    tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5)
+    tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5)
+
+
+@tvm.testing.requires_llvm
+def test_multibox_transform_loc_legalize_e2e_nonunity_variances():  # variances (0.1,0.1,0.2,0.2)
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(
+            cls: R.Tensor((1, 3, 5), "float32"),
+            loc: R.Tensor((1, 20), "float32"),
+            anc: R.Tensor((1, 5, 4), "float32"),
+        ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")):
+            return R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(0.1, 0.1, 0.2, 0.2),
+                keep_background=True,
+            )
+
+    # Seeded generators keep the inputs identical on every run, so a failure can be reproduced.
+    cls_data = np.random.default_rng(2).standard_normal((1, 3, 5)).astype(np.float32)
+    loc_data = np.random.default_rng(3).standard_normal((1, 20)).astype(np.float32) * 0.05
+    anc_data = np.array(
+        [
+            [
+                [0.1, 0.1, 0.5, 0.5],
+                [0.2, 0.2, 0.6, 0.6],
+                [0.0, 0.0, 1.0, 1.0],
+                [0.3, 0.3, 0.7, 0.7],
+                [0.05, 0.05, 0.45, 0.45],
+            ]
+        ],
+        dtype=np.float32,
+    )
+
+    mod = LegalizeOps()(Mod)
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (0.1, 0.1, 0.2, 0.2))
+    out = vm["main"](
+        tvm.runtime.tensor(cls_data, tvm.cpu()),
+        tvm.runtime.tensor(loc_data, tvm.cpu()),
+        tvm.runtime.tensor(anc_data, tvm.cpu()),
+    )
+    tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5)
+    tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5)
+
+
+@tvm.testing.requires_llvm
+def test_multibox_transform_loc_legalize_attr_branches():  # clip, threshold, background zeroing
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(
+            cls: R.Tensor((1, 3, 4), "float32"),
+            loc: R.Tensor((1, 16), "float32"),
+            anc: R.Tensor((1, 4, 4), "float32"),
+        ) -> R.Tuple(R.Tensor((1, 4, 4), "float32"), R.Tensor((1, 3, 4), "float32")):
+            return R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=True,
+                threshold=0.4,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=False,
+            )
+
+    cls_data = np.array(
+        [[[2.0, 0.1, -0.5, 0.0], [0.2, 2.2, 0.3, -1.0], [0.1, 0.4, 2.0, 0.5]]],
+        dtype=np.float32,
+    )
+    loc_data = np.array(
+        [[0.1, -0.2, 0.0, 0.0, -0.2, 0.1, 0.3, -0.1, 0.0, 0.0, 0.8, 0.8, 0.2, 0.2, -0.6, -0.6]],
+        dtype=np.float32,
+    )
+    anc_data = np.array(
+        [[[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6], [0.0, 0.0, 1.0, 1.0], [0.4, 0.4, 1.2, 1.2]]],
+        dtype=np.float32,
+    )
+
+    mod = LegalizeOps()(Mod)
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    ref_b, ref_s = _multibox_ref_numpy(
+        cls_data,
+        loc_data,
+        anc_data,
+        (1.0, 1.0, 1.0, 1.0),
+        clip=True,
+        threshold=0.4,
+        keep_background=False,
+    )
+    out = vm["main"](
+        tvm.runtime.tensor(cls_data, tvm.cpu()),
+        tvm.runtime.tensor(loc_data, tvm.cpu()),
+        tvm.runtime.tensor(anc_data, tvm.cpu()),
+    )
+    boxes = out[0].numpy()
+    scores = out[1].numpy()
+    tvm.testing.assert_allclose(boxes, ref_b, rtol=1e-4, atol=1e-5)
+    tvm.testing.assert_allclose(scores, ref_s, rtol=1e-4, atol=1e-5)
+    assert np.all(boxes >= 0.0) and np.all(boxes <= 1.0)  # clip=True bounds every coordinate
+    tvm.testing.assert_allclose(scores[:, 0, :], np.zeros_like(scores[:, 0, :]))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py
index c4e8ff0c9d22..f053e3674493 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -75,6 +75,48 @@ def foo(
     _check(foo, bb.get()["foo"])
 
 
+def test_multibox_transform_loc():  # parsed TVMScript must match the BlockBuilder-built function
+    @R.function
+    def foo(
+        cls: R.Tensor((1, 3, 5), "float32"),
+        loc: R.Tensor((1, 20), "float32"),
+        anc: R.Tensor((1, 5, 4), "float32"),
+    ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")):
+        gv: R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")) = (
+            R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=True,
+            )
+        )
+        return gv
+
+    cls = relax.Var("cls", R.Tensor((1, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((1, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [cls, loc, anc]):
+        gv = bb.emit(
+            relax.op.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=True,
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_roi_align():
     @R.function
     def foo(