Skip to content

API: use __array_wrap__ instead of .view(MatrixExpr) #1139

@Zeroto521

Description

@Zeroto521

Is your feature request related to a problem? Please describe.

MatrixVariable is a subclass of MatrixExpr.
The result of a calculation on a MatrixVariable should be a MatrixExpr, not a MatrixVariable (see #1117).

Describe the solution you'd like

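NumPy provides the __array_wrap__ protocol, which lets an ndarray subclass control the type of ufunc and reduction results. Overriding it fixes this return-type problem without calling .view(MatrixExpr) on every result.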

Additional context

A demo showing this:

import numpy as np


class MatrixBase(np.ndarray):
    """ndarray subclass that decides the type of ufunc and reduction results.

    NumPy calls ``__array_wrap__`` on the output of a ufunc (``+``, ``==``, ...)
    or a reduction (``.sum()``) computed from an instance of this class.
    Overriding it makes every subclass, including ``MatrixVariable``, yield
    ``MatrixExpr`` results instead of keeping its own type, and turns the
    comparison ufuncs in ``_CONS_UFUNCS`` into ``MatrixExprCons`` results.
    """

    # Ufuncs whose results are re-typed as MatrixExprCons instead of MatrixExpr.
    _CONS_UFUNCS = (np.less_equal, np.greater_equal, np.equal)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        """Re-type the raw result ``array`` produced by NumPy.

        Parameters
        ----------
        array : np.ndarray
            The result NumPy computed.
        context : tuple or None
            ``(ufunc, args, output_index)`` when NumPy supplies it, else ``None``.
        return_scalar : bool
            Passed by NumPy >= 2; true when a 0-d result should become a scalar.

        Returns
        -------
        ``array.item()`` for a 0-d result when ``return_scalar`` is true;
        a ``MatrixExprCons`` view for the ufuncs in ``_CONS_UFUNCS``;
        otherwise a ``MatrixExpr`` view. A non-array result from the base
        implementation is returned unchanged.
        """
        # NumPy < 2 defines ndarray.__array_wrap__(array, context) and rejects a
        # third positional argument. It also never passes return_scalar, so
        # forwarding it only when set keeps both NumPy 1.x and 2.x working.
        if return_scalar:
            res = super().__array_wrap__(array, context, return_scalar)
        else:
            res = super().__array_wrap__(array, context)

        if not isinstance(res, np.ndarray):
            return res
        if return_scalar and res.ndim == 0:
            return res.item()
        if context is not None and context[0] in self._CONS_UFUNCS:
            return res.view(MatrixExprCons)
        return res.view(MatrixExpr)


class MatrixExpr(MatrixBase):
    """Default result type of ufuncs and reductions on MatrixBase arrays."""


class MatrixVariable(MatrixExpr):
    """Subclass of MatrixExpr whose operation results are MatrixExpr, not MatrixVariable."""


class MatrixExprCons(np.ndarray):
    """Result type of ``<=``, ``>=`` and ``==`` on MatrixBase arrays.

    It subclasses ``np.ndarray`` directly, so it does not inherit
    ``MatrixBase.__array_wrap__``.
    """


if __name__ == "__main__":
    # Demo: print the result types produced by MatrixBase.__array_wrap__.
    a = np.arange(12).reshape((3, 4)).view(MatrixVariable)
    print("Original MatrixVariable:")
    print(type(a))
    print(a)
    # Original MatrixVariable:
    # <class '__main__.MatrixVariable'>
    # [[ 0  1  2  3]
    #  [ 4  5  6  7]
    #  [ 8  9 10 11]]

    b = a + 10
    print("\nAfter addition:")
    # The result is MatrixExpr, not MatrixVariable: MatrixBase.__array_wrap__
    # re-types the ufunc output instead of keeping the input's own subclass.
    print(type(b))
    print(b)
    # After addition:
    # <class '__main__.MatrixExpr'>
    # [[10 11 12 13]
    #  [14 15 16 17]
    #  [18 19 20 21]]

    c = a.sum(axis=1)
    print("\nAfter sum:")
    # Reductions also pass through __array_wrap__, so .sum() yields MatrixExpr.
    print(type(c))
    print(c)
    # After sum:
    # <class '__main__.MatrixExpr'>
    # [ 6 22 38]

    d = a == b
    print("\nAfter comparison (a == b):")
    # np.equal is one of the comparison ufuncs that __array_wrap__ maps to
    # MatrixExprCons.
    print(type(d))
    print(d)
    # After comparison (a == b):
    # <class '__main__.MatrixExprCons'>
    # [[False False False False]
    #  [False False False False]
    #  [False False False False]]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions