Stratified splits for dataset #3

@sfluegel05

Description

For the dataset created with the dataset builder, I need stratified splits. There is already a splitter implementation, but it is not sufficient for our multilabel dataset.
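
As a quick illustration of why plain sklearn stratification breaks down here (a minimal sketch, assuming labels are binary indicator vectors): StratifiedShuffleSplit treats every distinct label combination as its own class, so any combination that occurs only once makes the split fail, while iterstrat stratifies each label separately.

import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit

y = np.array([[1, 0], [1, 0], [0, 1], [0, 1], [1, 1]])  # [1, 1] occurs once
X = np.zeros((len(y), 1))  # dummy features; only their count matters

try:
    # sklearn joins each row into a single class label ("1 1" here), and a
    # class with one member cannot be split
    next(StratifiedShuffleSplit(n_splits=1, test_size=0.4).split(X, y))
except ValueError as e:
    print("sklearn:", e)

# iterstrat balances each label column separately, so this succeeds:
train_idx, test_idx = next(
    MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=0).split(X, y)
)
print("iterstrat:", train_idx, test_idx)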

Use the iterstrat or sklearn packages to build a stratified split. Here is an example from a similar library:

    # Excerpt from a class; assumes `import pandas as pd` and
    # `from typing import Dict, Optional, Tuple, Union` at module level
    def get_test_split(
        self, df: pd.DataFrame, seed: Optional[int] = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Split the input DataFrame into training and testing sets based on multilabel stratified sampling.

        This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels
        in the training and testing sets is approximately the same. The split is based on the "labels" column
        in the DataFrame.

        Args:
            df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column
                               named "labels" with the multilabel data.
            seed (int, optional): The random seed to be used for reproducibility. Default is None.

        Returns:
            Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames.

        Raises:
            ValueError: If the DataFrame does not contain a column named "labels".
        """
        from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
        from sklearn.model_selection import StratifiedShuffleSplit

        print("Get test data split")

        labels_list = df["labels"].tolist()

        # Multilabel data (more than one label per row) needs iterstrat;
        # single-label data can use plain sklearn stratification
        if len(labels_list[0]) > 1:
            splitter = MultilabelStratifiedShuffleSplit(
                n_splits=1, test_size=self.test_split, random_state=seed
            )
        else:
            splitter = StratifiedShuffleSplit(
                n_splits=1, test_size=self.test_split, random_state=seed
            )

        # X is only used for its length by these splitters, so the labels
        # serve as both X and y
        train_indices, test_indices = next(splitter.split(labels_list, labels_list))

        df_train = df.iloc[train_indices]
        df_test = df.iloc[test_indices]
        return df_train, df_test

    def get_train_val_splits_given_test(
        self, df: pd.DataFrame, test_df: pd.DataFrame, seed: Optional[int] = None
    ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
        """
        Split the dataset into train and validation sets, given a test set.
        Use the given test set (e.g., loaded from another source or generated in get_test_split) to avoid overlap.

        Args:
            df (pd.DataFrame): The original dataset.
            test_df (pd.DataFrame): The test dataset.
            seed (int, optional): The random seed to be used for reproducibility. Default is None.

        Returns:
            Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and
                validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train
                and validation DataFrames. The keys are the names of the train and validation sets, and the values
                are the corresponding DataFrames.
        """
        from iterstrat.ml_stratifiers import (
            MultilabelStratifiedKFold,
            MultilabelStratifiedShuffleSplit,
        )
        from sklearn.model_selection import StratifiedShuffleSplit

        print("Split dataset into train / val with given test set")

        test_ids = test_df["ident"].tolist()
        df_trainval = df[~df["ident"].isin(test_ids)]
        labels_list_trainval = df_trainval["labels"].tolist()

        if self.use_inner_cross_validation:
            folds = {}
            kfold = MultilabelStratifiedKFold(
                # shuffle=True is required when a random_state is given
                n_splits=self.inner_k_folds, shuffle=True, random_state=seed
            )
            for fold, (train_ids, val_ids) in enumerate(
                kfold.split(
                    labels_list_trainval,
                    labels_list_trainval,
                )
            ):
                df_validation = df_trainval.iloc[val_ids]
                df_train = df_trainval.iloc[train_ids]
                folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
                folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = (
                    df_validation
                )

            return folds

        if len(labels_list_trainval[0]) > 1:
            splitter = MultilabelStratifiedShuffleSplit(
                n_splits=1,
                test_size=self.validation_split / (1 - self.test_split),
                random_state=seed,
            )
        else:
            splitter = StratifiedShuffleSplit(
                n_splits=1,
                test_size=self.validation_split / (1 - self.test_split),
                random_state=seed,
            )

        train_indices, validation_indices = next(
            splitter.split(labels_list_trainval, labels_list_trainval)
        )

        df_validation = df_trainval.iloc[validation_indices]
        df_train = df_trainval.iloc[train_indices]
        return df_train, df_validation

Ignore the part about cross-validation and unify both methods into a single function, as sketched below.
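
A minimal sketch of what the unified function could look like (the name get_splits and the split-fraction arguments are my assumptions; it keeps the dispatch between MultilabelStratifiedShuffleSplit and StratifiedShuffleSplit from the example above and drops the inner cross-validation branch):

from typing import Optional, Tuple

import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit


def get_splits(
    df: pd.DataFrame,
    test_split: float = 0.1,
    validation_split: float = 0.1,
    seed: Optional[int] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split df into stratified train / validation / test DataFrames."""
    if "labels" not in df.columns:
        raise ValueError('DataFrame must contain a "labels" column')

    def _split(data: pd.DataFrame, size: float) -> Tuple[pd.DataFrame, pd.DataFrame]:
        labels = data["labels"].tolist()
        # Multilabel rows need iterstrat; single-label rows can use plain sklearn
        if len(labels[0]) > 1:
            splitter = MultilabelStratifiedShuffleSplit(
                n_splits=1, test_size=size, random_state=seed
            )
        else:
            splitter = StratifiedShuffleSplit(
                n_splits=1, test_size=size, random_state=seed
            )
        rest_idx, split_idx = next(splitter.split(labels, labels))
        return data.iloc[rest_idx], data.iloc[split_idx]

    df_trainval, df_test = _split(df, test_split)
    # validation_split is defined relative to the full dataset, but the second
    # split only sees the remaining (1 - test_split) fraction, so rescale it
    df_train, df_validation = _split(
        df_trainval, validation_split / (1 - test_split)
    )
    return df_train, df_validation, df_test

If the test set is produced elsewhere (as in get_train_val_splits_given_test), the same function could take an optional test_df argument and filter df by "ident" before the second split instead of computing the test set itself.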
