Skip to content

Commit 305c01c

Browse files
committed
FEAT: generalize set_labels to support groups as keys and "expressions" as labels
array.set_labels({'a0:a1': 'A0..A1'}) (closes #906)
1 parent 7652500 commit 305c01c

6 files changed

Lines changed: 194 additions & 137 deletions

File tree

doc/source/changes/version_0_34.rst.inc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Syntax changes
55
^^^^^^^^^^^^^^
66

7-
* renamed ``Array.old_method_name()`` to :py:obj:`Array.new_method_name()` (closes :issue:`1`).
7+
* ``Axis.apply()`` and ``Axis.replace()`` are deprecated in favor of :py:obj:`Axis.set_labels()`.
88

99
* renamed ``old_argument_name`` argument of :py:obj:`Array.method_name()` to ``new_argument_name``.
1010

@@ -52,6 +52,24 @@ Miscellaneous improvements
5252
* made all I/O functions/methods/constructors to accept either a string or a pathlib.Path object
5353
for all arguments representing a path (closes :issue:`896`).
5454

55+
* :py:obj:`Array.set_labels()` and :py:obj:`Axis.set_labels()` (formerly ``Axis.replace()`` and ``Axis.apply()``) now
56+
accept slices, Groups or selection strings as labels to change and callables and "creation strings" as new labels, so
57+
that it is easier to change only a subset of labels or to change several labels in the same way (closes :issue:`906`).
58+
59+
>>> arr = ndtest((2, 3))
60+
>>> arr
61+
a\b b0 b1 b2
62+
a0 0 1 2
63+
a1 3 4 5
64+
>>> arr.set_labels({'b1:': str.upper, 'a1': 'A-ONE'})
65+
a\b b0 B1 B2
66+
a0 0 1 2
67+
A-ONE 3 4 5
68+
>>> arr.set_labels('b1:', 'B1..B2')
69+
a\b b0 B1 B2
70+
a0 0 1 2
71+
a1 3 4 5
72+
5573
* added type hints for all remaining functions and methods which improves autocompletion in editors (such as PyCharm).
5674
Closes :issue:`864`.
5775

larray/core/array.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7424,7 +7424,6 @@ def __array__(self, dtype=None):
74247424

74257425
__array_priority__ = 100
74267426

7427-
# TODO: this should be a thin wrapper around a method in AxisCollection
74287427
def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'Array':
74297428
r"""Replaces the labels of one or several axes of the array.
74307429
@@ -7522,13 +7521,18 @@ def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'Array'
75227521
nat\sex Men F
75237522
Belgian 0 1
75247523
FO 2 3
7524+
7525+
>>> a.set_labels({'M:F': str.lower, 'BE': 'Belgian', 'FO': 'Foreigner'})
7526+
nat\sex m f
7527+
Belgian 0 1
7528+
Foreigner 2 3
75257529
"""
7526-
axes = self.axes.set_labels(axis, labels, **kwargs)
7530+
new_axes = self.axes.set_labels(axis, labels, **kwargs)
75277531
if inplace:
7528-
self.axes = axes
7532+
self.axes = new_axes
75297533
return self
75307534
else:
7531-
return Array(self.data, axes)
7535+
return Array(self.data, new_axes)
75327536

75337537
def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True) -> 'Array':
75347538
return Array(self.data.astype(dtype, order, casting, subok, copy), self.axes)

larray/core/axis.py

Lines changed: 89 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212
from larray.core.abstractbases import ABCAxis, ABCAxisReference, ABCArray
1313
from larray.core.expr import ExprNode
14-
from larray.core.group import (Group, LGroup, IGroup, IGroupMaker, _to_tick, _to_ticks, _to_key, _seq_summary,
15-
_idx_seq_to_slice, _seq_group_to_name, _translate_group_key_hdf, remove_nested_groups)
14+
from larray.core.group import (Group, LGroup, IGroup, IGroupMaker, _to_label, _to_labels, _to_key, _seq_summary,
15+
_idx_seq_to_slice, _seq_group_to_name, _translate_group_key_hdf, remove_nested_groups,
16+
_to_label_or_labels)
1617
from larray.util.oset import OrderedSet
1718
from larray.util.misc import (duplicates, array_lookup2, ReprString, index_by_id, renamed_to, common_type, LHDFStore,
1819
lazy_attribute, _isnoneslice, unique_list, unique_multi, Product, argsort, has_duplicates,
@@ -195,7 +196,7 @@ def labels(self, labels):
195196
labels = np.arange(length)
196197
iswildcard = True
197198
else:
198-
labels = _to_ticks(labels, parse_single_int=True)
199+
labels = _to_labels(labels, parse_single_int=True)
199200
length = len(labels)
200201
iswildcard = False
201202

@@ -883,7 +884,7 @@ def _ipython_key_completions_(self) -> List[Scalar]:
883884

884885
def __contains__(self, key) -> bool:
885886
# TODO: ideally, _to_label shouldn't be necessary, the __hash__ and __eq__ of Group should include this
886-
return _to_tick(key) in self._mapping
887+
return _to_label(key) in self._mapping
887888

888889
# use the default hash. We have to specify it explicitly because we define __eq__
889890
__hash__ = object.__hash__
@@ -919,14 +920,17 @@ def index(self, key) -> Union[int, np.ndarray, slice]:
919920
3
920921
>>> people.index(people.containing('Bruce'))
921922
array([1, 2])
923+
>>> a = Axis('a0..a5', 'a')
924+
>>> a.index('a1,a3,a2..a4')
925+
array([1, 3, 2, 3, 4])
922926
"""
923927
mapping = self._mapping
924928

925929
if isinstance(key, Group) and key.axis is not self and key.axis is not None:
926930
try:
927931
# XXX: this is potentially very expensive if key.key is an array or list and should be tried as a last
928932
# resort
929-
potential_tick = _to_tick(key)
933+
potential_tick = _to_label(key)
930934
# avoid matching 0 against False or 0.0, note that None has object dtype and so always pass this test
931935
if self._is_key_type_compatible(potential_tick):
932936
return mapping[potential_tick]
@@ -1121,73 +1125,91 @@ def copy(self) -> 'Axis':
11211125
new_axis.__sorted_values = self.__sorted_values
11221126
return new_axis
11231127

1124-
def replace(self, old, new=None) -> 'Axis':
1128+
def set_labels(self, old_or_changes, new=None) -> 'Axis':
11251129
r"""
1126-
Returns a new axis with some labels replaced.
1130+
Returns a new axis with some labels changed.
11271131
1128-
Parameters
1129-
----------
1130-
old : any scalar (bool, int, str, ...), tuple/list/array of scalars, or a mapping.
1131-
the label(s) to be replaced. Old can be a mapping {old1: new1, old2: new2, ...}
1132-
new : any scalar (bool, int, str, ...) or tuple/list/array of scalars, optional
1133-
the new label(s). This is argument must not be used if old is a mapping.
1132+
It supports three distinct syntax variants:
11341133
1135-
Returns
1136-
-------
1137-
Axis
1138-
a new Axis with the old labels replaced by new labels.
1134+
* Axis.set_labels(new_labels) -> replace all Axis labels by `new_labels`
1135+
* Axis.set_labels(label_selection, new_labels) -> replace selection of labels by `new_labels`
1136+
* Axis.set_labels({old1: new1, old2: new2}) -> replace each selection of labels by corresponding new labels
11391137
1140-
Examples
1141-
--------
1142-
>>> sex = Axis('sex=M,F')
1143-
>>> sex
1144-
Axis(['M', 'F'], 'sex')
1145-
>>> sex.replace('M', 'Male')
1146-
Axis(['Male', 'F'], 'sex')
1147-
>>> sex.replace({'M': 'Male', 'F': 'Female'})
1148-
Axis(['Male', 'Female'], 'sex')
1149-
>>> sex.replace(['M', 'F'], ['Male', 'Female'])
1150-
Axis(['Male', 'Female'], 'sex')
1151-
"""
1152-
if isinstance(old, dict):
1153-
new = list(old.values())
1154-
old = list(old.keys())
1155-
elif np.isscalar(old):
1156-
assert new is not None and np.isscalar(new), f"{new} is not a scalar but a {type(new).__name__}"
1157-
old = [old]
1158-
new = [new]
1159-
else:
1160-
seq = (tuple, list, np.ndarray)
1161-
assert isinstance(old, seq), f"{old} is not a sequence but a {type(old).__name__}"
1162-
assert isinstance(new, seq), f"{new} is not a sequence but a {type(new).__name__}"
1163-
assert len(old) == len(new)
1164-
# using object dtype because new labels length can be larger than the fixed str length in the self.labels array
1165-
labels = self.labels.astype(object)
1166-
indices = self.index(old)
1167-
labels[indices] = new
1168-
return Axis(labels, self.name)
1169-
1170-
def apply(self, func) -> 'Axis':
1171-
r"""
1172-
Returns a new axis with the labels transformed by func.
1138+
Additionally, new labels in any of the above forms can be a function which transforms the existing
1139+
labels to produce the actual new labels.
11731140
11741141
Parameters
11751142
----------
1176-
func : callable
1177-
A callable which takes a single argument and returns a single value.
1143+
old_or_changes : any scalar (bool, int, str, ...), tuple/list/array of scalars, Group, callable or mapping.
1144+
This can be either:
1145+
1146+
* A selection of label(s) to be replaced. This can take several forms:
1147+
- a single label (e.g. 'France')
1148+
- a list of labels (e.g. ['France', 'Germany'])
1149+
- a comma-separated string of labels (e.g. 'France,Germany')
1150+
- a Group (e.g. country['France'])
1151+
* A mapping {selection1: new_labels1, selection2: new_labels2, ...}
1152+
* New labels, in which case all the axis labels will be replaced by these new labels and
1153+
the `new` argument must not be used.
1154+
new : any scalar (bool, int, str, ...) or tuple/list/array of scalars or callable, optional
1155+
The new label(s) or function to apply to old labels to get the new labels. This argument must not be
1156+
used if `old_or_changes` contains the new labels or if it is a mapping.
11781157
11791158
Returns
11801159
-------
11811160
Axis
1182-
a new Axis with the transformed labels.
1161+
a new Axis with the old labels replaced by new labels.
11831162
11841163
Examples
11851164
--------
1186-
>>> sex = Axis('sex=MALE,FEMALE')
1187-
>>> sex.apply(str.capitalize)
1188-
Axis(['Male', 'Female'], 'sex')
1189-
"""
1190-
return Axis(np_frompyfunc(func, 1, 1)(self.labels), self.name)
1165+
>>> country = Axis('country=be,de,fr')
1166+
>>> country
1167+
Axis(['be', 'de', 'fr'], 'country')
1168+
>>> country.set_labels('be', 'Belgium')
1169+
Axis(['Belgium', 'de', 'fr'], 'country')
1170+
>>> country.set_labels({'de': 'Germany', 'fr': 'France'})
1171+
Axis(['be', 'Germany', 'France'], 'country')
1172+
>>> country.set_labels(['be', 'fr'], ['Belgium', 'France'])
1173+
Axis(['Belgium', 'de', 'France'], 'country')
1174+
>>> country.set_labels('be,de', 'Belgium-Germany')
1175+
Axis(['Belgium-Germany', 'Belgium-Germany', 'fr'], 'country')
1176+
>>> country.set_labels('be,de', ['Belgium', 'Germany'])
1177+
Axis(['Belgium', 'Germany', 'fr'], 'country')
1178+
>>> country.set_labels(str.upper)
1179+
Axis(['BE', 'DE', 'FR'], 'country')
1180+
"""
1181+
# FIXME: compute max(length of new keys and old labels array) instead
1182+
# XXX: it might be easier to go via list to get the label type auto-detection
1183+
# labels = self.labels.tolist()
1184+
1185+
# using object dtype because new labels length can be larger than the fixed str length in self.labels
1186+
labels = self.labels.astype(object)
1187+
get_indices = self.index
1188+
1189+
def apply_changes(selection, label_change):
1190+
old_indices = get_indices(selection)
1191+
if callable(label_change):
1192+
old_labels = labels[old_indices]
1193+
if isinstance(old_labels, np.ndarray):
1194+
np_func = np_frompyfunc(label_change, 1, 1)
1195+
new_labels = np_func(old_labels)
1196+
else:
1197+
new_labels = label_change(old_labels)
1198+
else:
1199+
new_labels = _to_label_or_labels(label_change)
1200+
labels[old_indices] = new_labels
1201+
1202+
if new is None and not isinstance(old_or_changes, dict):
1203+
apply_changes(slice(None), old_or_changes)
1204+
elif new is not None:
1205+
apply_changes(old_or_changes, new)
1206+
else:
1207+
assert new is None and isinstance(old_or_changes, dict)
1208+
for old, new in old_or_changes.items():
1209+
apply_changes(old, new)
1210+
return Axis(labels, self.name)
1211+
apply = renamed_to(set_labels, 'apply')
1212+
replace = renamed_to(set_labels, 'replace')
11911213

11921214
# XXX: rename to named like Group?
11931215
def rename(self, name) -> 'Axis':
@@ -1196,7 +1218,7 @@ def rename(self, name) -> 'Axis':
11961218
11971219
Parameters
11981220
----------
1199-
name : str
1221+
name : str, Axis
12001222
the new name for the axis.
12011223
12021224
Returns
@@ -1252,7 +1274,7 @@ def union(self, other) -> 'Axis':
12521274
"""
12531275
if isinstance(other, str):
12541276
# TODO : remove [other] if ... when FutureWarning raised in Axis.init will be removed
1255-
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other]
1277+
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other]
12561278
if isinstance(other, Axis):
12571279
other = other.labels
12581280
return Axis(unique_multi((self.labels, other)), self.name)
@@ -1288,7 +1310,7 @@ def intersection(self, other) -> 'Axis':
12881310
"""
12891311
if isinstance(other, str):
12901312
# TODO : remove [other] if ... when FutureWarning raised in Axis.init will be removed
1291-
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other]
1313+
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other]
12921314
if isinstance(other, Axis):
12931315
other = other.labels
12941316
to_keep = set(other)
@@ -1325,7 +1347,7 @@ def difference(self, other) -> 'Axis':
13251347
"""
13261348
if isinstance(other, str):
13271349
# TODO : remove [other] if ... when FutureWarning raised in Axis.init will be removed
1328-
other = _to_ticks(other, parse_single_int=True) if '..' in other or ',' in other else [other]
1350+
other = _to_labels(other, parse_single_int=True) if '..' in other or ',' in other else [other]
13291351
if isinstance(other, Axis):
13301352
other = other.labels
13311353
to_drop = set(other)
@@ -2567,24 +2589,13 @@ def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'AxisCo
25672589
# handle {label1: new_label1, label2: new_label2}
25682590
if any(axis_ref not in self for axis_ref in changes.keys()):
25692591
changes_per_axis = defaultdict(list)
2570-
for selection, new_labels in changes.items():
2592+
for selection, label_changes in changes.items():
25712593
group = self._guess_axis(selection)
2572-
changes_per_axis[group.axis].append((selection, new_labels))
2594+
changes_per_axis[group.axis].append((group, label_changes))
25732595
changes = {axis: dict(axis_changes) for axis, axis_changes in changes_per_axis.items()}
25742596

2575-
new_axes = []
2576-
for old_axis, axis_changes in changes.items():
2577-
real_axis = self[old_axis]
2578-
if isinstance(axis_changes, dict):
2579-
new_axis = real_axis.replace(axis_changes)
2580-
# TODO: we should implement the non-dict behavior in Axis.replace, so that we can simplify this code to:
2581-
# new_axes = [self[old_axis].replace(axis_changes) for old_axis, axis_changes in changes.items()]
2582-
elif callable(axis_changes):
2583-
new_axis = real_axis.apply(axis_changes)
2584-
else:
2585-
new_axis = Axis(axis_changes, real_axis.name)
2586-
new_axes.append((real_axis, new_axis))
2587-
return self.replace(new_axes, inplace=inplace)
2597+
return self.replace({old_axis: self[old_axis].set_labels(axis_changes) for old_axis, axis_changes in
2598+
changes.items()}, inplace=inplace)
25882599

25892600
# TODO: deprecate method (should use __sub__ instead)
25902601
def without(self, axes) -> 'AxisCollection':
@@ -3428,6 +3439,7 @@ def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'Axis
34283439
See Also
34293440
--------
34303441
Array.align
3442+
Axis.align
34313443
34323444
Examples
34333445
--------

0 commit comments

Comments
 (0)