Is your feature request related to a problem? Please describe.
MatrixVariable is a subclass of MatrixExpr, so by default NumPy operations on a MatrixVariable return a MatrixVariable. The result of a computation on a MatrixVariable should be a MatrixExpr, not a MatrixVariable, similar to #1117.
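For illustration, here is a bare subclass hierarchy with no NumPy hooks overridden. The class names are reused as hypothetical stand-ins, not the project's real classes. It shows NumPy's default behavior: elementwise ufuncs and reductions keep the subclass of their input, so arithmetic on a MatrixVariable yields another MatrixVariable.

```python
import numpy as np

# Hypothetical stand-ins with no NumPy hooks overridden
class MatrixExpr(np.ndarray): ...
class MatrixVariable(MatrixExpr): ...

a = np.arange(6).reshape(2, 3).view(MatrixVariable)
b = a + 10         # elementwise ufunc (np.add)
s = a.sum(axis=1)  # reduction (np.add.reduce)

# NumPy's default wrapping keeps the input's subclass for both results
print(type(b).__name__)  # MatrixVariable
print(type(s).__name__)  # MatrixVariable
```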
Describe the solution you'd like
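NumPy provides the __array_wrap__ protocol, which lets an ndarray subclass decide the type of the array returned by ufuncs and ufunc-based reductions. It can be used to fix the return type.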
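To show what the protocol receives, here is a minimal probe. It uses a hypothetical class `Probe` and assumes NumPy >= 2.0, where `__array_wrap__` is called with `(array, context, return_scalar)`. The probe records the ufunc that NumPy passes as `context[0]`. This is the information the comparison branch of the demo relies on to tell `np.equal` apart from `np.add`.

```python
import numpy as np

class Probe(np.ndarray):
    # Hypothetical class: logs which ufunc produced each wrapped result
    seen = []

    def __array_wrap__(self, array, context=None, return_scalar=False):
        # For ufunc calls, context is (ufunc, arguments, output index); it may be None
        Probe.seen.append(None if context is None else context[0])
        # Let ndarray do the actual wrapping (views the result as Probe)
        return super().__array_wrap__(array, context, return_scalar)

p = np.arange(3).view(Probe)
added = p + 1      # dispatches to np.add
compared = p == p  # dispatches to np.equal

print(Probe.seen)
```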
Additional context
A demo showing this:
```python
import numpy as np

# Requires NumPy >= 2.0, which passes `return_scalar` to __array_wrap__.


class MatrixBase(np.ndarray):
    def __array_wrap__(self, array, context=None, return_scalar=False):
        res = super().__array_wrap__(array, context, return_scalar)
        if return_scalar and isinstance(res, np.ndarray) and res.ndim == 0:
            # 0-d result (e.g. a full reduction): return a plain Python scalar
            return res.item()
        elif isinstance(res, np.ndarray):
            # Comparison ufuncs produce constraints; everything else is an expression
            if context is not None and context[0] in (
                np.less_equal,
                np.greater_equal,
                np.equal,
            ):
                return res.view(MatrixExprCons)
            return res.view(MatrixExpr)
        return res


class MatrixExpr(MatrixBase): ...
class MatrixVariable(MatrixExpr): ...
class MatrixExprCons(np.ndarray): ...


if __name__ == "__main__":
    a = np.arange(12).reshape((3, 4)).view(MatrixVariable)
    print("Original MatrixVariable:")
    print(type(a))
    print(a)
    # Original MatrixVariable:
    # <class '__main__.MatrixVariable'>
    # [[ 0  1  2  3]
    #  [ 4  5  6  7]
    #  [ 8  9 10 11]]

    b = a + 10
    print("\nAfter addition:")
    # MatrixExpr, not MatrixVariable: the default __array_ufunc__ passes the
    # result of np.add to __array_wrap__, which views it as MatrixExpr
    print(type(b))
    print(b)
    # After addition:
    # <class '__main__.MatrixExpr'>
    # [[10 11 12 13]
    #  [14 15 16 17]
    #  [18 19 20 21]]

    c = a.sum(axis=1)
    print("\nAfter sum:")
    # Also MatrixExpr: the reduction (np.add.reduce) goes through __array_wrap__ too
    print(type(c))
    print(c)
    # After sum:
    # <class '__main__.MatrixExpr'>
    # [ 6 22 38]

    d = a == b
    print("\nAfter comparison (a == b):")
    # MatrixExprCons, because __array_wrap__ sees np.equal in `context`
    print(type(d))
    print(d)
    # After comparison (a == b):
    # <class '__main__.MatrixExprCons'>
    # [[False False False False]
    #  [False False False False]
    #  [False False False False]]
```
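One point worth checking with this approach is that `__array_wrap__` only runs on the output of ufuncs and ufunc-based reductions. Indexing, slicing and transposing create views through `__array_finalize__` and never reach it. As a result, a slice of a MatrixVariable stays a MatrixVariable, while any arithmetic on it becomes a MatrixExpr. This is presumably the intended split, since selecting variables still yields variables, but that is an assumption and not something stated above. The sketch trims the demo's override to the demotion step and assumes NumPy >= 2.0.

```python
import numpy as np

class MatrixBase(np.ndarray):
    # Trimmed version of the demo's override: demote every computed array result
    def __array_wrap__(self, array, context=None, return_scalar=False):
        res = super().__array_wrap__(array, context, return_scalar)
        if isinstance(res, np.ndarray):
            return res.view(MatrixExpr)
        return res

class MatrixExpr(MatrixBase): ...
class MatrixVariable(MatrixExpr): ...

a = np.arange(12).reshape(3, 4).view(MatrixVariable)

row = a[0]         # basic indexing: a view, __array_wrap__ is not called
flipped = a.T      # transpose: also a view
shifted = row + 1  # ufunc output: goes through __array_wrap__

print(type(row).__name__)      # MatrixVariable
print(type(flipped).__name__)  # MatrixVariable
print(type(shifted).__name__)  # MatrixExpr
```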