1111
1212from larray .core .abstractbases import ABCAxis , ABCAxisReference , ABCArray
1313from larray .core .expr import ExprNode
14- from larray .core .group import (Group , LGroup , IGroup , IGroupMaker , _to_tick , _to_ticks , _to_key , _seq_summary ,
15- _idx_seq_to_slice , _seq_group_to_name , _translate_group_key_hdf , remove_nested_groups )
14+ from larray .core .group import (Group , LGroup , IGroup , IGroupMaker , _to_label , _to_labels , _to_key , _seq_summary ,
15+ _idx_seq_to_slice , _seq_group_to_name , _translate_group_key_hdf , remove_nested_groups ,
16+ _to_label_or_labels )
1617from larray .util .oset import OrderedSet
1718from larray .util .misc import (duplicates , array_lookup2 , ReprString , index_by_id , renamed_to , common_type , LHDFStore ,
1819 lazy_attribute , _isnoneslice , unique_list , unique_multi , Product , argsort , has_duplicates ,
@@ -195,7 +196,7 @@ def labels(self, labels):
195196 labels = np .arange (length )
196197 iswildcard = True
197198 else :
198- labels = _to_ticks (labels , parse_single_int = True )
199+ labels = _to_labels (labels , parse_single_int = True )
199200 length = len (labels )
200201 iswildcard = False
201202
@@ -883,7 +884,7 @@ def _ipython_key_completions_(self) -> List[Scalar]:
883884
884885 def __contains__ (self , key ) -> bool :
885886 # TODO: ideally, _to_label shouldn't be necessary, the __hash__ and __eq__ of Group should include this
886- return _to_tick (key ) in self ._mapping
887+ return _to_label (key ) in self ._mapping
887888
888889 # use the default hash. We have to specify it explicitly because we define __eq__
889890 __hash__ = object .__hash__
@@ -919,14 +920,17 @@ def index(self, key) -> Union[int, np.ndarray, slice]:
919920 3
920921 >>> people.index(people.containing('Bruce'))
921922 array([1, 2])
923+ >>> a = Axis('a0..a5', 'a')
924+ >>> a.index('a1,a3,a2..a4')
925+ array([1, 3, 2, 3, 4])
922926 """
923927 mapping = self ._mapping
924928
925929 if isinstance (key , Group ) and key .axis is not self and key .axis is not None :
926930 try :
927931 # XXX: this is potentially very expensive if key.key is an array or list and should be tried as a last
928932 # resort
929- potential_tick = _to_tick (key )
933+ potential_tick = _to_label (key )
930934 # avoid matching 0 against False or 0.0, note that None has object dtype and so always pass this test
931935 if self ._is_key_type_compatible (potential_tick ):
932936 return mapping [potential_tick ]
@@ -1121,73 +1125,91 @@ def copy(self) -> 'Axis':
11211125 new_axis .__sorted_values = self .__sorted_values
11221126 return new_axis
11231127
1124- def replace (self , old , new = None ) -> 'Axis' :
1128+ def set_labels (self , old_or_changes , new = None ) -> 'Axis' :
11251129 r"""
1126- Returns a new axis with some labels replaced .
1130+ Returns a new axis with some labels changed .
11271131
1128- Parameters
1129- ----------
1130- old : any scalar (bool, int, str, ...), tuple/list/array of scalars, or a mapping.
1131- the label(s) to be replaced. Old can be a mapping {old1: new1, old2: new2, ...}
1132- new : any scalar (bool, int, str, ...) or tuple/list/array of scalars, optional
1133- the new label(s). This is argument must not be used if old is a mapping.
1132+ It supports three distinct syntax variants:
11341133
1135- Returns
1136- -------
1137- Axis
1138- a new Axis with the old labels replaced by new labels.
1134+ * Axis.set_labels(new_labels) -> replace all Axis labels by `new_labels`
1135+ * Axis.set_labels(label_selection, new_labels) -> replace selection of labels by `new_labels`
1136+ * Axis.set_labels({old1: new1, old2: new2}) -> replace each selection of labels by corresponding new labels
11391137
1140- Examples
1141- --------
1142- >>> sex = Axis('sex=M,F')
1143- >>> sex
1144- Axis(['M', 'F'], 'sex')
1145- >>> sex.replace('M', 'Male')
1146- Axis(['Male', 'F'], 'sex')
1147- >>> sex.replace({'M': 'Male', 'F': 'Female'})
1148- Axis(['Male', 'Female'], 'sex')
1149- >>> sex.replace(['M', 'F'], ['Male', 'Female'])
1150- Axis(['Male', 'Female'], 'sex')
1151- """
1152- if isinstance (old , dict ):
1153- new = list (old .values ())
1154- old = list (old .keys ())
1155- elif np .isscalar (old ):
1156- assert new is not None and np .isscalar (new ), f"{ new } is not a scalar but a { type (new ).__name__ } "
1157- old = [old ]
1158- new = [new ]
1159- else :
1160- seq = (tuple , list , np .ndarray )
1161- assert isinstance (old , seq ), f"{ old } is not a sequence but a { type (old ).__name__ } "
1162- assert isinstance (new , seq ), f"{ new } is not a sequence but a { type (new ).__name__ } "
1163- assert len (old ) == len (new )
1164- # using object dtype because new labels length can be larger than the fixed str length in the self.labels array
1165- labels = self .labels .astype (object )
1166- indices = self .index (old )
1167- labels [indices ] = new
1168- return Axis (labels , self .name )
1169-
1170- def apply (self , func ) -> 'Axis' :
1171- r"""
1172- Returns a new axis with the labels transformed by func.
1138+ Additionally, new labels in any of the above forms can be a function which transforms the existing
1139+ labels to produce the actual new labels.
11731140
11741141 Parameters
11751142 ----------
1176- func : callable
1177- A callable which takes a single argument and returns a single value.
1143+ old_or_changes : any scalar (bool, int, str, ...), tuple/list/array of scalars, Group, callable or mapping.
1144+ This can be either:
1145+
1146+ * A selection of label(s) to be replaced. This can take several forms:
1147+ - a single label (e.g. 'France')
1148+ - a list of labels (e.g. ['France', 'Germany'])
1149+ - a comma-separated string of labels (e.g. 'France,Germany')
1150+ - a Group (e.g. country['France'])
1151+ * A mapping {selection1: new_labels1, selection2: new_labels2, ...}
1152+ * New labels, in which case all the axis labels will be replaced by these new labels and
1153+ the `new` argument must not be used.
1154+ new : any scalar (bool, int, str, ...) or tuple/list/array of scalars or callable, optional
1155+ The new label(s) or function to apply to old labels to get the new labels. This argument must not be
1156+ used if `old_or_changes` contains the new labels or if it is a mapping.
11781157
11791158 Returns
11801159 -------
11811160 Axis
1182- a new Axis with the transformed labels.
1161+ a new Axis with the old labels replaced by new labels.
11831162
11841163 Examples
11851164 --------
1186- >>> sex = Axis('sex=MALE,FEMALE')
1187- >>> sex.apply(str.capitalize)
1188- Axis(['Male', 'Female'], 'sex')
1189- """
1190- return Axis (np_frompyfunc (func , 1 , 1 )(self .labels ), self .name )
1165+ >>> country = Axis('country=be,de,fr')
1166+ >>> country
1167+ Axis(['be', 'de', 'fr'], 'country')
1168+ >>> country.set_labels('be', 'Belgium')
1169+ Axis(['Belgium', 'de', 'fr'], 'country')
1170+ >>> country.set_labels({'de': 'Germany', 'fr': 'France'})
1171+ Axis(['be', 'Germany', 'France'], 'country')
1172+ >>> country.set_labels(['be', 'fr'], ['Belgium', 'France'])
1173+ Axis(['Belgium', 'de', 'France'], 'country')
1174+ >>> country.set_labels('be,de', 'Belgium-Germany')
1175+ Axis(['Belgium-Germany', 'Belgium-Germany', 'fr'], 'country')
1176+ >>> country.set_labels('be,de', ['Belgium', 'Germany'])
1177+ Axis(['Belgium', 'Germany', 'fr'], 'country')
1178+ >>> country.set_labels(str.upper)
1179+ Axis(['BE', 'DE', 'FR'], 'country')
1180+ """
1181+ # FIXME: compute max(length of new keys and old labels array) instead
1182+ # XXX: it might be easier to go via list to get the label type auto-detection
1183+ # labels = self.labels.tolist()
1184+
1185+ # using object dtype because new labels length can be larger than the fixed str length in self.labels
1186+ labels = self .labels .astype (object )
1187+ get_indices = self .index
1188+
1189+ def apply_changes (selection , label_change ):
1190+ old_indices = get_indices (selection )
1191+ if callable (label_change ):
1192+ old_labels = labels [old_indices ]
1193+ if isinstance (old_labels , np .ndarray ):
1194+ np_func = np_frompyfunc (label_change , 1 , 1 )
1195+ new_labels = np_func (old_labels )
1196+ else :
1197+ new_labels = label_change (old_labels )
1198+ else :
1199+ new_labels = _to_label_or_labels (label_change )
1200+ labels [old_indices ] = new_labels
1201+
1202+ if new is None and not isinstance (old_or_changes , dict ):
1203+ apply_changes (slice (None ), old_or_changes )
1204+ elif new is not None :
1205+ apply_changes (old_or_changes , new )
1206+ else :
1207+ assert new is None and isinstance (old_or_changes , dict )
1208+ for old , new in old_or_changes .items ():
1209+ apply_changes (old , new )
1210+ return Axis (labels , self .name )
1211+ apply = renamed_to (set_labels , 'apply' )
1212+ replace = renamed_to (set_labels , 'replace' )
11911213
11921214 # XXX: rename to named like Group?
11931215 def rename (self , name ) -> 'Axis' :
@@ -1196,7 +1218,7 @@ def rename(self, name) -> 'Axis':
11961218
11971219 Parameters
11981220 ----------
1199- name : str
1221+ name : str, Axis
12001222 the new name for the axis.
12011223
12021224 Returns
@@ -1252,7 +1274,7 @@ def union(self, other) -> 'Axis':
12521274 """
12531275 if isinstance (other , str ):
12541276 # TODO : remove [other] if ... when the FutureWarning raised in Axis.init is removed
1255- other = _to_ticks (other , parse_single_int = True ) if '..' in other or ',' in other else [other ]
1277+ other = _to_labels (other , parse_single_int = True ) if '..' in other or ',' in other else [other ]
12561278 if isinstance (other , Axis ):
12571279 other = other .labels
12581280 return Axis (unique_multi ((self .labels , other )), self .name )
@@ -1288,7 +1310,7 @@ def intersection(self, other) -> 'Axis':
12881310 """
12891311 if isinstance (other , str ):
12901312 # TODO : remove [other] if ... when the FutureWarning raised in Axis.init is removed
1291- other = _to_ticks (other , parse_single_int = True ) if '..' in other or ',' in other else [other ]
1313+ other = _to_labels (other , parse_single_int = True ) if '..' in other or ',' in other else [other ]
12921314 if isinstance (other , Axis ):
12931315 other = other .labels
12941316 to_keep = set (other )
@@ -1325,7 +1347,7 @@ def difference(self, other) -> 'Axis':
13251347 """
13261348 if isinstance (other , str ):
13271349 # TODO : remove [other] if ... when the FutureWarning raised in Axis.init is removed
1328- other = _to_ticks (other , parse_single_int = True ) if '..' in other or ',' in other else [other ]
1350+ other = _to_labels (other , parse_single_int = True ) if '..' in other or ',' in other else [other ]
13291351 if isinstance (other , Axis ):
13301352 other = other .labels
13311353 to_drop = set (other )
@@ -2567,24 +2589,13 @@ def set_labels(self, axis=None, labels=None, inplace=False, **kwargs) -> 'AxisCo
25672589 # handle {label1: new_label1, label2: new_label2}
25682590 if any (axis_ref not in self for axis_ref in changes .keys ()):
25692591 changes_per_axis = defaultdict (list )
2570- for selection , new_labels in changes .items ():
2592+ for selection , label_changes in changes .items ():
25712593 group = self ._guess_axis (selection )
2572- changes_per_axis [group .axis ].append ((selection , new_labels ))
2594+ changes_per_axis [group .axis ].append ((group , label_changes ))
25732595 changes = {axis : dict (axis_changes ) for axis , axis_changes in changes_per_axis .items ()}
25742596
2575- new_axes = []
2576- for old_axis , axis_changes in changes .items ():
2577- real_axis = self [old_axis ]
2578- if isinstance (axis_changes , dict ):
2579- new_axis = real_axis .replace (axis_changes )
2580- # TODO: we should implement the non-dict behavior in Axis.replace, so that we can simplify this code to:
2581- # new_axes = [self[old_axis].replace(axis_changes) for old_axis, axis_changes in changes.items()]
2582- elif callable (axis_changes ):
2583- new_axis = real_axis .apply (axis_changes )
2584- else :
2585- new_axis = Axis (axis_changes , real_axis .name )
2586- new_axes .append ((real_axis , new_axis ))
2587- return self .replace (new_axes , inplace = inplace )
2597+ return self .replace ({old_axis : self [old_axis ].set_labels (axis_changes ) for old_axis , axis_changes in
2598+ changes .items ()}, inplace = inplace )
25882599
25892600 # TODO: deprecate method (should use __sub__ instead)
25902601 def without (self , axes ) -> 'AxisCollection' :
@@ -3428,6 +3439,7 @@ def align(self, other, join='outer', axes=None) -> Tuple['AxisCollection', 'Axis
34283439 See Also
34293440 --------
34303441 Array.align
3442+ Axis.align
34313443
34323444 Examples
34333445 --------
0 commit comments