-
Notifications
You must be signed in to change notification settings - Fork 0
Closed
Description
For the dataset created with the dataset builder, I need stratified splits. There is already a splitter implementation, but it is not sufficient for our multilabel dataset.
Use the iterstrat or sklearn packages to build a stratified split. Here is an example from a similar library:
def get_test_split(
    self, df: pd.DataFrame, seed: Optional[int] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split *df* into train and test sets with (multi)label-stratified sampling.

    Uses ``MultilabelStratifiedShuffleSplit`` when rows carry more than one
    label, otherwise falls back to sklearn's ``StratifiedShuffleSplit``, so the
    label distribution of both partitions approximates that of the full frame.
    The test fraction is taken from ``self.test_split``.

    Args:
        df (pd.DataFrame): Input data; must contain a "labels" column where
            each entry is a sequence of labels for that row.
        seed (int, optional): Random seed for reproducibility. Default None.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: ``(df_train, df_test)``.

    Raises:
        ValueError: If *df* has no "labels" column or is empty.
    """
    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
    from sklearn.model_selection import StratifiedShuffleSplit

    print("Get test data split")
    # The docstring promised this ValueError, but the original code never
    # raised it — a missing column surfaced as an opaque KeyError instead.
    if "labels" not in df.columns:
        raise ValueError('DataFrame must contain a column named "labels"')
    labels_list = df["labels"].tolist()
    if not labels_list:
        # Original indexed labels_list[0] unconditionally -> IndexError on
        # an empty frame; fail with an actionable message instead.
        raise ValueError("Cannot split an empty DataFrame")
    # Dispatch on the whole column, not just the first row: a dataset whose
    # first sample happens to have a single label is still multilabel if any
    # other row has several.
    if any(len(labels) > 1 for labels in labels_list):
        splitter = MultilabelStratifiedShuffleSplit(
            n_splits=1, test_size=self.test_split, random_state=seed
        )
    else:
        splitter = StratifiedShuffleSplit(
            n_splits=1, test_size=self.test_split, random_state=seed
        )
    # Both splitters only use y for stratification; passing the labels as X
    # as well is sufficient because we only need the index arrays back.
    train_indices, test_indices = next(splitter.split(labels_list, labels_list))
    df_train = df.iloc[train_indices]
    df_test = df.iloc[test_indices]
    return df_train, df_test
def get_train_val_splits_given_test(
    self, df: pd.DataFrame, test_df: pd.DataFrame, seed: Optional[int] = None
) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
    """Split *df* into train and validation sets, excluding a given test set.

    Rows whose "ident" appears in *test_df* are removed first so the test set
    (e.g. produced by :meth:`get_test_split` or loaded externally) never
    overlaps the train/validation data. When
    ``self.use_inner_cross_validation`` is set, k stratified folds are
    produced instead of a single split.

    Args:
        df (pd.DataFrame): The original dataset; must contain "ident" and
            "labels" columns.
        test_df (pd.DataFrame): The test dataset; must contain "ident".
        seed (int, optional): Random seed for reproducibility. Default None.

    Returns:
        Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
            A dict mapping fold file names (from ``self.raw_file_names_dict``)
            to DataFrames when inner cross-validation is enabled, otherwise
            the tuple ``(df_train, df_validation)``.

    Raises:
        ValueError: If no rows remain after removing the test identifiers.
    """
    from iterstrat.ml_stratifiers import (
        MultilabelStratifiedKFold,
        MultilabelStratifiedShuffleSplit,
    )
    from sklearn.model_selection import StratifiedShuffleSplit

    print("Split dataset into train / val with given test set")
    test_ids = test_df["ident"].tolist()
    df_trainval = df[~df["ident"].isin(test_ids)]
    labels_list_trainval = df_trainval["labels"].tolist()
    if not labels_list_trainval:
        # Original indexed labels_list_trainval[0] -> IndexError when every
        # row belonged to the test set; raise a clear error instead.
        raise ValueError("No rows left for train/validation after removing test set")
    if self.use_inner_cross_validation:
        folds = {}
        # shuffle must be enabled for random_state to take effect; the
        # sklearn-style validator otherwise rejects (or silently ignores)
        # the seed. Only shuffle when a seed was actually requested, so the
        # seedless behavior matches the original deterministic fold order.
        kfold = MultilabelStratifiedKFold(
            n_splits=self.inner_k_folds,
            shuffle=seed is not None,
            random_state=seed,
        )
        for fold, (train_ids, val_ids) in enumerate(
            kfold.split(labels_list_trainval, labels_list_trainval)
        ):
            df_validation = df_trainval.iloc[val_ids]
            df_train = df_trainval.iloc[train_ids]
            folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
            folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = (
                df_validation
            )
        return folds
    # Rescale the validation fraction to the remaining (1 - test) share so
    # the final train/val/test proportions match the configured splits.
    # NOTE(review): assumes self.test_split < 1 — a test_split of 1.0 would
    # divide by zero here; confirm upstream validation.
    val_fraction = self.validation_split / (1 - self.test_split)
    # Use the multilabel splitter if ANY row has multiple labels — the
    # original checked only the first row, mis-dispatching ragged label sets.
    if any(len(labels) > 1 for labels in labels_list_trainval):
        splitter = MultilabelStratifiedShuffleSplit(
            n_splits=1, test_size=val_fraction, random_state=seed
        )
    else:
        splitter = StratifiedShuffleSplit(
            n_splits=1, test_size=val_fraction, random_state=seed
        )
    train_indices, validation_indices = next(
        splitter.split(labels_list_trainval, labels_list_trainval)
    )
    df_validation = df_trainval.iloc[validation_indices]
    df_train = df_trainval.iloc[train_indices]
    return df_train, df_validation
Ignore the cross-validation part and unify both methods into a single function.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels