From da7c4331377fc943765cd19bd63bf1c640c7907b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:14:45 +0000 Subject: [PATCH 1/7] Initial plan From b6d4cbcf2db803d4f9d0953ea80da3dadbc9003b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:19:47 +0000 Subject: [PATCH 2/7] Add create_multilabel_splits with iterstrat/sklearn for multilabel stratified splits Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- chebi_utils/__init__.py | 3 +- chebi_utils/splitter.py | 110 ++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 + tests/test_splitter.py | 88 +++++++++++++++++++++++++++++++- 4 files changed, 201 insertions(+), 2 deletions(-) diff --git a/chebi_utils/__init__.py b/chebi_utils/__init__.py index 8367bab..ca290c3 100644 --- a/chebi_utils/__init__.py +++ b/chebi_utils/__init__.py @@ -1,7 +1,7 @@ from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf from chebi_utils.obo_extractor import build_chebi_graph from chebi_utils.sdf_extractor import extract_molecules -from chebi_utils.splitter import create_splits +from chebi_utils.splitter import create_multilabel_splits, create_splits __all__ = [ "download_chebi_obo", @@ -9,4 +9,5 @@ "build_chebi_graph", "extract_molecules", "create_splits", + "create_multilabel_splits", ] diff --git a/chebi_utils/splitter.py b/chebi_utils/splitter.py index dd86141..2b14555 100644 --- a/chebi_utils/splitter.py +++ b/chebi_utils/splitter.py @@ -6,6 +6,116 @@ import pandas as pd +def create_multilabel_splits( + df: pd.DataFrame, + labels_col: str, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + seed: int | None = 42, +) -> dict[str, pd.DataFrame]: + """Create stratified train/validation/test splits for multilabel DataFrames. + + Automatically detects whether the dataset is multilabel (each entry has + more than one label) or single-label, and applies the appropriate + stratification strategy: + + - Multilabel: uses ``MultilabelStratifiedShuffleSplit`` from the + ``iterative-stratification`` package. + - Single-label: uses ``StratifiedShuffleSplit`` from ``scikit-learn``. + + Parameters + ---------- + df : pd.DataFrame + Input data to split. Must contain a column ``labels_col`` whose + values are sequences of labels (e.g. lists of strings or ints). + labels_col : str + Name of the column that contains the label sequences. + train_ratio : float + Fraction of data for training (default 0.8). + val_ratio : float + Fraction of data for validation (default 0.1). + test_ratio : float + Fraction of data for testing (default 0.1). + seed : int or None + Random seed for reproducibility. + + Returns + ------- + dict + Dictionary with keys ``'train'``, ``'val'``, ``'test'``, each + containing a DataFrame. + + Raises + ------ + ValueError + If the ratios do not sum to 1, any ratio is outside ``[0, 1]``, or + ``labels_col`` is not found in *df*. + """ + if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: + raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0") + if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]): + raise ValueError("All ratios must be between 0 and 1") + if labels_col not in df.columns: + raise ValueError(f"Column '{labels_col}' not found in DataFrame") + + from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit + from sklearn.model_selection import StratifiedShuffleSplit + from sklearn.preprocessing import MultiLabelBinarizer + + labels_list: list[list] = df[labels_col].tolist() + is_multilabel = any(len(lbl) > 1 for lbl in labels_list) + + df_reset = df.reset_index(drop=True) + + if is_multilabel: + mlb = MultiLabelBinarizer() + labels_matrix = mlb.fit_transform(labels_list) + else: + labels_matrix = [lbl[0] for lbl in labels_list] + + # ── Step 1: carve out the test set ────────────────────────────────────── + if is_multilabel: + test_splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_ratio, random_state=seed + ) + train_val_idx, test_idx = next(test_splitter.split(labels_matrix, labels_matrix)) + else: + test_splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=seed) + train_val_idx, test_idx = next(test_splitter.split(labels_matrix, labels_matrix)) + + df_test = df_reset.iloc[test_idx] + df_trainval = df_reset.iloc[train_val_idx] + + # ── Step 2: split train/val from the remaining data ───────────────────── + labels_trainval = ( + labels_matrix[train_val_idx] + if is_multilabel + else [labels_matrix[i] for i in train_val_idx] + ) + val_ratio_adjusted = val_ratio / (1.0 - test_ratio) + + if is_multilabel: + val_splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=val_ratio_adjusted, random_state=seed + ) + train_idx_inner, val_idx_inner = next(val_splitter.split(labels_trainval, labels_trainval)) + else: + val_splitter = StratifiedShuffleSplit( + n_splits=1, test_size=val_ratio_adjusted, random_state=seed + ) + train_idx_inner, val_idx_inner = next(val_splitter.split(labels_trainval, labels_trainval)) + + df_train = df_trainval.iloc[train_idx_inner] + df_val = df_trainval.iloc[val_idx_inner] + + return { + "train": df_train.reset_index(drop=True), + "val": df_val.reset_index(drop=True), + "test": df_test.reset_index(drop=True), + } + + def create_splits( df: pd.DataFrame, train_ratio: float = 0.8, diff --git a/pyproject.toml b/pyproject.toml index 86b4fe7..4be79f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,10 +11,12 @@ license = { file = "LICENSE" } requires-python = ">=3.10" dependencies = [ "fastobo>=0.14", + "iterative-stratification>=0.1.9", "networkx>=3.0", "numpy>=1.24", "pandas>=2.0", "rdkit>=2022.09", + "scikit-learn>=1.0", "chembl_structure_pipeline>=1.2.4", ] diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 10b9434..2506a4f 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -5,7 +5,7 @@ import pandas as pd import pytest -from chebi_utils.splitter import create_splits +from chebi_utils.splitter import create_multilabel_splits, create_splits @pytest.fixture @@ -103,3 +103,89 @@ def test_stratified_reproducible(self, sample_df): splits1 = create_splits(sample_df, stratify_col="category", seed=42) splits2 = create_splits(sample_df, stratify_col="category", seed=42) pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) + + +@pytest.fixture +def multilabel_df(): + """DataFrame with multilabel 'labels' column (200 rows).""" + all_labels = [["A"], ["B"], ["C"], ["A", "B"], ["A", "C"], ["B", "C"]] + labels = [all_labels[i % len(all_labels)] for i in range(200)] + return pd.DataFrame( + { + "id": [f"CHEBI:{i}" for i in range(200)], + "labels": labels, + } + ) + + +@pytest.fixture +def singlelabel_df(): + """DataFrame with single-label 'labels' column.""" + return pd.DataFrame( + { + "id": [f"CHEBI:{i}" for i in range(200)], + "labels": [["A"] if i % 2 == 0 else ["B"] for i in range(200)], + } + ) + + +class TestCreateMultilabelSplits: + def test_returns_three_splits(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df, labels_col="labels") + assert set(splits.keys()) == {"train", "val", "test"} + + def test_sizes_sum_to_total(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df, labels_col="labels") + assert sum(len(v) for v in splits.values()) == len(multilabel_df) + + def test_no_overlap(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df, labels_col="labels") + train_ids = set(splits["train"]["id"]) + val_ids = set(splits["val"]["id"]) + test_ids = set(splits["test"]["id"]) + assert train_ids.isdisjoint(val_ids) + assert train_ids.isdisjoint(test_ids) + assert val_ids.isdisjoint(test_ids) + + def test_all_rows_covered(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df, labels_col="labels") + all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"]) + assert all_ids == set(multilabel_df["id"]) + + def test_reproducible_with_same_seed(self, multilabel_df): + splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7) + splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7) + pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) + + def test_different_seeds_give_different_splits(self, multilabel_df): + splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=1) + splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=2) + assert not splits1["train"]["id"].equals(splits2["train"]["id"]) + + def test_approximate_split_sizes(self, multilabel_df): + splits = create_multilabel_splits( + multilabel_df, labels_col="labels", train_ratio=0.8, val_ratio=0.1, test_ratio=0.1 + ) + n = len(multilabel_df) + assert abs(len(splits["test"]) - int(n * 0.1)) <= 2 + assert abs(len(splits["val"]) - int(n * 0.1)) <= 2 + + def test_invalid_ratios_raise_error(self, multilabel_df): + with pytest.raises(ValueError, match="must equal 1.0"): + create_multilabel_splits( + multilabel_df, labels_col="labels", train_ratio=0.5, val_ratio=0.3, test_ratio=0.3 + ) + + def test_missing_labels_col_raises_error(self, multilabel_df): + with pytest.raises(ValueError, match="not found in DataFrame"): + create_multilabel_splits(multilabel_df, labels_col="nonexistent") + + def test_singlelabel_path(self, singlelabel_df): + """Single-label lists should use StratifiedShuffleSplit without error.""" + splits = create_multilabel_splits(singlelabel_df, labels_col="labels") + assert sum(len(v) for v in splits.values()) == len(singlelabel_df) + train_ids = set(splits["train"]["id"]) + val_ids = set(splits["val"]["id"]) + test_ids = set(splits["test"]["id"]) + assert train_ids.isdisjoint(val_ids) + assert train_ids.isdisjoint(test_ids) From ab26a27433021e822545f5b6dfdcc12f8c8d5375 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:09:32 +0000 Subject: [PATCH 3/7] Rework splitter: use per-label bool columns, remove create_splits Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- chebi_utils/__init__.py | 3 +- chebi_utils/splitter.py | 158 ++++++-------------------------- tests/test_splitter.py | 196 +++++++++++++--------------------------- 3 files changed, 91 insertions(+), 266 deletions(-) diff --git a/chebi_utils/__init__.py b/chebi_utils/__init__.py index ca290c3..36abf0a 100644 --- a/chebi_utils/__init__.py +++ b/chebi_utils/__init__.py @@ -1,13 +1,12 @@ from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf from chebi_utils.obo_extractor import build_chebi_graph from chebi_utils.sdf_extractor import extract_molecules -from chebi_utils.splitter import create_multilabel_splits, create_splits +from chebi_utils.splitter import create_multilabel_splits __all__ = [ "download_chebi_obo", "download_chebi_sdf", "build_chebi_graph", "extract_molecules", - "create_splits", "create_multilabel_splits", ] diff --git a/chebi_utils/splitter.py b/chebi_utils/splitter.py index 2b14555..ead7129 100644 --- a/chebi_utils/splitter.py +++ b/chebi_utils/splitter.py @@ -2,13 +2,12 @@ from __future__ import annotations -import numpy as np import pandas as pd def create_multilabel_splits( df: pd.DataFrame, - labels_col: str, + label_start_col: int = 2, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, @@ -16,21 +15,23 @@ def create_multilabel_splits( ) -> dict[str, pd.DataFrame]: """Create stratified train/validation/test splits for multilabel DataFrames. - Automatically detects whether the dataset is multilabel (each entry has - more than one label) or single-label, and applies the appropriate - stratification strategy: + Columns from index *label_start_col* onwards are treated as binary label + columns (one boolean column per label). The stratification strategy is + chosen automatically based on the number of label columns: - - Multilabel: uses ``MultilabelStratifiedShuffleSplit`` from the - ``iterative-stratification`` package. - - Single-label: uses ``StratifiedShuffleSplit`` from ``scikit-learn``. + - More than one label column: ``MultilabelStratifiedShuffleSplit`` from + the ``iterative-stratification`` package. + - Single label column: ``StratifiedShuffleSplit`` from ``scikit-learn``. Parameters ---------- df : pd.DataFrame - Input data to split. Must contain a column ``labels_col`` whose - values are sequences of labels (e.g. lists of strings or ints). - labels_col : str - Name of the column that contains the label sequences. + Input data. Columns ``0`` to ``label_start_col - 1`` are treated as + feature/metadata columns; all remaining columns are boolean label + columns. A typical ChEBI DataFrame has columns + ``["chebi_id", "mol", "label1", "label2", ...]``. + label_start_col : int + Index of the first label column (default 2). train_ratio : float Fraction of data for training (default 0.8). val_ratio : float @@ -50,61 +51,53 @@ def create_multilabel_splits( ------ ValueError If the ratios do not sum to 1, any ratio is outside ``[0, 1]``, or - ``labels_col`` is not found in *df*. + *label_start_col* is out of range. """ if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0") if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]): raise ValueError("All ratios must be between 0 and 1") - if labels_col not in df.columns: - raise ValueError(f"Column '{labels_col}' not found in DataFrame") + if label_start_col >= len(df.columns): + raise ValueError( + f"label_start_col={label_start_col} is out of range for a DataFrame " + f"with {len(df.columns)} columns" + ) from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit from sklearn.model_selection import StratifiedShuffleSplit - from sklearn.preprocessing import MultiLabelBinarizer - labels_list: list[list] = df[labels_col].tolist() - is_multilabel = any(len(lbl) > 1 for lbl in labels_list) + labels_matrix = df.iloc[:, label_start_col:].values + is_multilabel = labels_matrix.shape[1] > 1 + # StratifiedShuffleSplit requires a 1-D label array + y = labels_matrix if is_multilabel else labels_matrix[:, 0] df_reset = df.reset_index(drop=True) - if is_multilabel: - mlb = MultiLabelBinarizer() - labels_matrix = mlb.fit_transform(labels_list) - else: - labels_matrix = [lbl[0] for lbl in labels_list] - # ── Step 1: carve out the test set ────────────────────────────────────── if is_multilabel: test_splitter = MultilabelStratifiedShuffleSplit( n_splits=1, test_size=test_ratio, random_state=seed ) - train_val_idx, test_idx = next(test_splitter.split(labels_matrix, labels_matrix)) else: test_splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=seed) - train_val_idx, test_idx = next(test_splitter.split(labels_matrix, labels_matrix)) + train_val_idx, test_idx = next(test_splitter.split(y, y)) df_test = df_reset.iloc[test_idx] df_trainval = df_reset.iloc[train_val_idx] # ── Step 2: split train/val from the remaining data ───────────────────── - labels_trainval = ( - labels_matrix[train_val_idx] - if is_multilabel - else [labels_matrix[i] for i in train_val_idx] - ) + y_trainval = y[train_val_idx] val_ratio_adjusted = val_ratio / (1.0 - test_ratio) if is_multilabel: val_splitter = MultilabelStratifiedShuffleSplit( n_splits=1, test_size=val_ratio_adjusted, random_state=seed ) - train_idx_inner, val_idx_inner = next(val_splitter.split(labels_trainval, labels_trainval)) else: val_splitter = StratifiedShuffleSplit( n_splits=1, test_size=val_ratio_adjusted, random_state=seed ) - train_idx_inner, val_idx_inner = next(val_splitter.split(labels_trainval, labels_trainval)) + train_idx_inner, val_idx_inner = next(val_splitter.split(y_trainval, y_trainval)) df_train = df_trainval.iloc[train_idx_inner] df_val = df_trainval.iloc[val_idx_inner] @@ -114,102 +107,3 @@ def create_multilabel_splits( "val": df_val.reset_index(drop=True), "test": df_test.reset_index(drop=True), } - - -def create_splits( - df: pd.DataFrame, - train_ratio: float = 0.8, - val_ratio: float = 0.1, - test_ratio: float = 0.1, - stratify_col: str | None = None, - seed: int = 42, -) -> dict[str, pd.DataFrame]: - """Create stratified train/validation/test splits of a DataFrame. - - Parameters - ---------- - df : pd.DataFrame - Input data to split. - train_ratio : float - Fraction of data for training (default 0.8). - val_ratio : float - Fraction of data for validation (default 0.1). - test_ratio : float - Fraction of data for testing (default 0.1). - stratify_col : str or None - Column name to use for stratification. If None, splits are random. - seed : int - Random seed for reproducibility. - - Returns - ------- - dict - Dictionary with keys ``'train'``, ``'val'``, ``'test'``, each - containing a DataFrame. - - Raises - ------ - ValueError - If the ratios do not sum to 1 or any ratio is outside ``[0, 1]``. - """ - if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: - raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0") - if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]): - raise ValueError("All ratios must be between 0 and 1") - - rng = np.random.default_rng(seed) - - if stratify_col is not None: - return _stratified_split(df, train_ratio, val_ratio, test_ratio, stratify_col, rng) - return _random_split(df, train_ratio, val_ratio, test_ratio, rng) - - -def _random_split( - df: pd.DataFrame, - train_ratio: float, - val_ratio: float, - test_ratio: float, # noqa: ARG001 - rng: np.random.Generator, -) -> dict[str, pd.DataFrame]: - indices = rng.permutation(len(df)) - n_train = int(len(df) * train_ratio) - n_val = int(len(df) * val_ratio) - - train_idx = indices[:n_train] - val_idx = indices[n_train : n_train + n_val] - test_idx = indices[n_train + n_val :] - - return { - "train": df.iloc[train_idx].reset_index(drop=True), - "val": df.iloc[val_idx].reset_index(drop=True), - "test": df.iloc[test_idx].reset_index(drop=True), - } - - -def _stratified_split( - df: pd.DataFrame, - train_ratio: float, - val_ratio: float, - test_ratio: float, # noqa: ARG001 - stratify_col: str, - rng: np.random.Generator, -) -> dict[str, pd.DataFrame]: - train_indices: list[int] = [] - val_indices: list[int] = [] - test_indices: list[int] = [] - - for _, group in df.groupby(stratify_col, sort=False): - group_indices = rng.permutation(np.array(group.index.tolist())) - n = len(group_indices) - n_train = max(1, int(n * train_ratio)) - n_val = max(0, int(n * val_ratio)) - - train_indices.extend(group_indices[:n_train].tolist()) - val_indices.extend(group_indices[n_train : n_train + n_val].tolist()) - test_indices.extend(group_indices[n_train + n_val :].tolist()) - - return { - "train": df.loc[train_indices].reset_index(drop=True), - "val": df.loc[val_indices].reset_index(drop=True), - "test": df.loc[test_indices].reset_index(drop=True), - } diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 2506a4f..0f5e470 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -5,187 +5,119 @@ import pandas as pd import pytest -from chebi_utils.splitter import create_multilabel_splits, create_splits - - -@pytest.fixture -def sample_df(): - return pd.DataFrame( - { - "id": [f"CHEBI:{i}" for i in range(100)], - "category": (["A"] * 50) + (["B"] * 30) + (["C"] * 20), - } - ) - - -class TestCreateSplitsRandom: - def test_returns_three_splits(self, sample_df): - splits = create_splits(sample_df) - assert set(splits.keys()) == {"train", "val", "test"} - - def test_split_sizes_sum_to_total(self, sample_df): - splits = create_splits(sample_df) - total = sum(len(v) for v in splits.values()) - assert total == len(sample_df) - - def test_default_ratios(self, sample_df): - splits = create_splits(sample_df) - assert len(splits["train"]) == 80 - assert len(splits["val"]) == 10 - assert len(splits["test"]) == 10 - - def test_reproducible_with_same_seed(self, sample_df): - splits1 = create_splits(sample_df, seed=0) - splits2 = create_splits(sample_df, seed=0) - pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) - - def test_different_seeds_give_different_splits(self, sample_df): - splits1 = create_splits(sample_df, seed=0) - splits2 = create_splits(sample_df, seed=1) - assert not splits1["train"]["id"].equals(splits2["train"]["id"]) - - def test_no_overlap_between_splits(self, sample_df): - splits = create_splits(sample_df) - train_ids = set(splits["train"]["id"]) - val_ids = set(splits["val"]["id"]) - test_ids = set(splits["test"]["id"]) - assert train_ids.isdisjoint(val_ids) - assert train_ids.isdisjoint(test_ids) - assert val_ids.isdisjoint(test_ids) - - def test_all_rows_covered(self, sample_df): - splits = create_splits(sample_df) - all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"]) - assert all_ids == set(sample_df["id"]) - - def test_custom_ratios(self, sample_df): - splits = create_splits(sample_df, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1) - assert len(splits["train"]) == 70 - assert len(splits["val"]) == 20 - assert len(splits["test"]) == 10 - - def test_invalid_ratios_raise_error(self, sample_df): - with pytest.raises(ValueError, match="must equal 1.0"): - create_splits(sample_df, train_ratio=0.5, val_ratio=0.3, test_ratio=0.3) - - def test_negative_ratio_raises_error(self, sample_df): - with pytest.raises(ValueError, match="between 0 and 1"): - create_splits(sample_df, train_ratio=-0.1, val_ratio=0.6, test_ratio=0.5) - - -class TestCreateSplitsStratified: - def test_stratified_split_returns_three_splits(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - assert set(splits.keys()) == {"train", "val", "test"} - - def test_stratified_sizes_sum_to_total(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - total = sum(len(v) for v in splits.values()) - assert total == len(sample_df) - - def test_stratified_preserves_class_proportions(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - train = splits["train"] - counts = train["category"].value_counts(normalize=True) - # Category A: 50% of data -> ~50% of train - assert abs(counts.get("A", 0) - 0.5) < 0.1 - - def test_stratified_no_overlap(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - train_ids = set(splits["train"]["id"]) - val_ids = set(splits["val"]["id"]) - test_ids = set(splits["test"]["id"]) - assert train_ids.isdisjoint(val_ids) - assert train_ids.isdisjoint(test_ids) - assert val_ids.isdisjoint(test_ids) - - def test_stratified_reproducible(self, sample_df): - splits1 = create_splits(sample_df, stratify_col="category", seed=42) - splits2 = create_splits(sample_df, stratify_col="category", seed=42) - pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) +from chebi_utils.splitter import create_multilabel_splits @pytest.fixture def multilabel_df(): - """DataFrame with multilabel 'labels' column (200 rows).""" - all_labels = [["A"], ["B"], ["C"], ["A", "B"], ["A", "C"], ["B", "C"]] - labels = [all_labels[i % len(all_labels)] for i in range(200)] + """DataFrame with three bool label columns starting at index 2 (200 rows). + + Column layout: chebi_id | mol | label_A | label_B | label_C + Each row gets one of six label combinations in a round-robin pattern. + """ + patterns = [ + (True, False, False), + (False, True, False), + (False, False, True), + (True, True, False), + (True, False, True), + (False, True, True), + ] + rows = [patterns[i % len(patterns)] for i in range(200)] + label_a, label_b, label_c = zip(*rows) return pd.DataFrame( { - "id": [f"CHEBI:{i}" for i in range(200)], - "labels": labels, + "chebi_id": [f"CHEBI:{i}" for i in range(200)], + "mol": ["mol"] * 200, + "label_A": list(label_a), + "label_B": list(label_b), + "label_C": list(label_c), } ) @pytest.fixture def singlelabel_df(): - """DataFrame with single-label 'labels' column.""" + """DataFrame with a single bool label column at index 2 (200 rows).""" return pd.DataFrame( { - "id": [f"CHEBI:{i}" for i in range(200)], - "labels": [["A"] if i % 2 == 0 else ["B"] for i in range(200)], + "chebi_id": [f"CHEBI:{i}" for i in range(200)], + "mol": ["mol"] * 200, + "label_A": [i % 2 == 0 for i in range(200)], } ) class TestCreateMultilabelSplits: def test_returns_three_splits(self, multilabel_df): - splits = create_multilabel_splits(multilabel_df, labels_col="labels") + splits = create_multilabel_splits(multilabel_df) assert set(splits.keys()) == {"train", "val", "test"} def test_sizes_sum_to_total(self, multilabel_df): - splits = create_multilabel_splits(multilabel_df, labels_col="labels") + splits = create_multilabel_splits(multilabel_df) assert sum(len(v) for v in splits.values()) == len(multilabel_df) def test_no_overlap(self, multilabel_df): - splits = create_multilabel_splits(multilabel_df, labels_col="labels") - train_ids = set(splits["train"]["id"]) - val_ids = set(splits["val"]["id"]) - test_ids = set(splits["test"]["id"]) + splits = create_multilabel_splits(multilabel_df) + train_ids = set(splits["train"]["chebi_id"]) + val_ids = set(splits["val"]["chebi_id"]) + test_ids = set(splits["test"]["chebi_id"]) assert train_ids.isdisjoint(val_ids) assert train_ids.isdisjoint(test_ids) assert val_ids.isdisjoint(test_ids) def test_all_rows_covered(self, multilabel_df): - splits = create_multilabel_splits(multilabel_df, labels_col="labels") - all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"]) - assert all_ids == set(multilabel_df["id"]) + splits = create_multilabel_splits(multilabel_df) + all_ids = ( + set(splits["train"]["chebi_id"]) + | set(splits["val"]["chebi_id"]) + | set(splits["test"]["chebi_id"]) + ) + assert all_ids == set(multilabel_df["chebi_id"]) + + def test_label_columns_preserved(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df) + for split in splits.values(): + assert list(split.columns) == list(multilabel_df.columns) def test_reproducible_with_same_seed(self, multilabel_df): - splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7) - splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7) + splits1 = create_multilabel_splits(multilabel_df, seed=7) + splits2 = create_multilabel_splits(multilabel_df, seed=7) pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) def test_different_seeds_give_different_splits(self, multilabel_df): - splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=1) - splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=2) - assert not splits1["train"]["id"].equals(splits2["train"]["id"]) + splits1 = create_multilabel_splits(multilabel_df, seed=1) + splits2 = create_multilabel_splits(multilabel_df, seed=2) + assert not splits1["train"]["chebi_id"].equals(splits2["train"]["chebi_id"]) def test_approximate_split_sizes(self, multilabel_df): splits = create_multilabel_splits( - multilabel_df, labels_col="labels", train_ratio=0.8, val_ratio=0.1, test_ratio=0.1 + multilabel_df, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1 ) n = len(multilabel_df) assert abs(len(splits["test"]) - int(n * 0.1)) <= 2 assert abs(len(splits["val"]) - int(n * 0.1)) <= 2 + def test_custom_label_start_col(self, multilabel_df): + # Drop the 'mol' column so labels start at index 1 + df_no_mol = multilabel_df.drop(columns=["mol"]) + splits = create_multilabel_splits(df_no_mol, label_start_col=1) + assert sum(len(v) for v in splits.values()) == len(df_no_mol) + def test_invalid_ratios_raise_error(self, multilabel_df): with pytest.raises(ValueError, match="must equal 1.0"): - create_multilabel_splits( - multilabel_df, labels_col="labels", train_ratio=0.5, val_ratio=0.3, test_ratio=0.3 - ) + create_multilabel_splits(multilabel_df, train_ratio=0.5, val_ratio=0.3, test_ratio=0.3) - def test_missing_labels_col_raises_error(self, multilabel_df): - with pytest.raises(ValueError, match="not found in DataFrame"): - create_multilabel_splits(multilabel_df, labels_col="nonexistent") + def test_out_of_range_label_start_col_raises_error(self, multilabel_df): + with pytest.raises(ValueError, match="out of range"): + create_multilabel_splits(multilabel_df, label_start_col=100) def test_singlelabel_path(self, singlelabel_df): - """Single-label lists should use StratifiedShuffleSplit without error.""" - splits = create_multilabel_splits(singlelabel_df, labels_col="labels") + """Single label column should use StratifiedShuffleSplit without error.""" + splits = create_multilabel_splits(singlelabel_df) assert sum(len(v) for v in splits.values()) == len(singlelabel_df) - train_ids = set(splits["train"]["id"]) - val_ids = set(splits["val"]["id"]) - test_ids = set(splits["test"]["id"]) + train_ids = set(splits["train"]["chebi_id"]) + val_ids = set(splits["val"]["chebi_id"]) + test_ids = set(splits["test"]["chebi_id"]) assert train_ids.isdisjoint(val_ids) assert train_ids.isdisjoint(test_ids) From 39e22ecd671721dc251372fcff6fd31caea04aa9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:16:36 +0000 Subject: [PATCH 4/7] Fix ruff formatting in obo_extractor.py and test_obo_extractor.py Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- chebi_utils/obo_extractor.py | 1 + tests/test_obo_extractor.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py index 9d59c11..e982c79 100644 --- a/chebi_utils/obo_extractor.py +++ b/chebi_utils/obo_extractor.py @@ -115,6 +115,7 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: return graph + def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.DiGraph: """Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a). Also removes nodes that are not connected by any is_a edges to other nodes.""" return chebi_graph.edge_subgraph( diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py index c416fd2..43d8647 100644 --- a/tests/test_obo_extractor.py +++ b/tests/test_obo_extractor.py @@ -41,7 +41,9 @@ def test_node_name_attribute(self): def test_smiles_extracted_from_property_value(self): g = build_chebi_graph(SAMPLE_OBO) - expected = "COc1cc2c3cc1Oc1c(O)c(OC)cc4c1[C@H](Cc1ccc(O)c(c1)Oc1ccc(cc1)C[C@@H]3N(C)CC2)N(C)CC4" + expected = ( + "COc1cc2c3cc1Oc1c(O)c(OC)cc4c1[C@H](Cc1ccc(O)c(c1)Oc1ccc(cc1)C[C@@H]3N(C)CC2)N(C)CC4" + ) assert g.nodes["10"]["smiles"] == expected def test_smiles_none_when_absent(self): From d7c45fba65d1cb74df9fd5464515a9f03697b241 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 27 Feb 2026 14:22:54 +0100 Subject: [PATCH 5/7] reformat ruff --- chebi_utils/obo_extractor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py index e982c79..1fc427f 100644 --- a/chebi_utils/obo_extractor.py +++ b/chebi_utils/obo_extractor.py @@ -117,7 +117,8 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.DiGraph: - """Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a). Also removes nodes that are not connected by any is_a edges to other nodes.""" + """Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a). + Also removes nodes that are not connected by any is_a edges to other nodes.""" return chebi_graph.edge_subgraph( (u, v) for u, v, d in chebi_graph.edges(data=True) if d.get("relation") == "is_a" ) From 079519380437725dd4dd4ded99ab4d4baaafbd8c Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 27 Feb 2026 15:01:54 +0100 Subject: [PATCH 6/7] fix unit tests --- tests/fixtures/sample.obo | 2 +- tests/test_obo_extractor.py | 12 +++++++----- tests/test_sdf_extractor.py | 14 +++++++------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/fixtures/sample.obo b/tests/fixtures/sample.obo index dd586c3..afda749 100644 --- a/tests/fixtures/sample.obo +++ b/tests/fixtures/sample.obo @@ -84,4 +84,4 @@ xref: wikipedia.en:Starch {source="wikipedia.en"} is_a: CHEBI:37163 ! glucan relationship: BFO:0000051 CHEBI:28057 ! has part amylopectin relationship: BFO:0000051 CHEBI:28102 ! has part amylose -relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite \ No newline at end of file +relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py index 43d8647..c502ffe 100644 --- a/tests/test_obo_extractor.py +++ b/tests/test_obo_extractor.py @@ -18,9 +18,10 @@ def test_returns_directed_graph(self): assert isinstance(g, nx.DiGraph) def test_correct_number_of_nodes(self): - # CHEBI:27189 is obsolete -> excluded; 3 explicit + 1 implicit (24921) = 4 + # CHEBI:27189 is obsolete -> excluded; + # 4 explicit + 5 implicit (superclasses and relation targets) = 9 g = build_chebi_graph(SAMPLE_OBO) - assert len(g.nodes) == 4 + assert len(g.nodes) == 9 def test_node_ids_are_strings(self): g = build_chebi_graph(SAMPLE_OBO) @@ -28,7 +29,8 @@ def test_node_ids_are_strings(self): def test_expected_nodes_present(self): g = build_chebi_graph(SAMPLE_OBO) - assert set(g.nodes) == {"10", "133004", "22750", "24921"} + assert set(g.nodes) == {"10", "133004", "22750", "24921", + "28017", '75771', '28057', '28102', '37163'} def test_obsolete_term_excluded(self): g = build_chebi_graph(SAMPLE_OBO) @@ -71,8 +73,8 @@ def test_isa_chain(self): def test_total_edge_count(self): g = build_chebi_graph(SAMPLE_OBO) - # 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a) - assert len(g.edges) == 3 + # 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a), ... + assert len(g.edges) == 7 def test_xref_lines_do_not_break_parsing(self, tmp_path): obo_with_xrefs = tmp_path / "xref.obo" diff --git a/tests/test_sdf_extractor.py b/tests/test_sdf_extractor.py index b7be8a5..2aaf223 100644 --- a/tests/test_sdf_extractor.py +++ b/tests/test_sdf_extractor.py @@ -24,7 +24,7 @@ def test_chebi_id_column_present(self): def test_chebi_ids_correct(self): df = extract_molecules(SAMPLE_SDF) - assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + assert set(df["chebi_id"]) == {"1", "2"} def test_name_column_present(self): df = extract_molecules(SAMPLE_SDF) @@ -57,8 +57,8 @@ def test_mol_objects_are_rdkit_mol(self): def test_mol_atom_counts(self): df = extract_molecules(SAMPLE_SDF) - row1 = df[df["chebi_id"] == "CHEBI:1"].iloc[0] - row2 = df[df["chebi_id"] == "CHEBI:2"].iloc[0] + row1 = df[df["chebi_id"] == "1"].iloc[0] + row2 = df[df["chebi_id"] == "2"].iloc[0] assert row1["mol"].GetNumAtoms() == 1 # methane: 1 C assert row2["mol"].GetNumAtoms() == 2 # ethane: 2 C @@ -70,7 +70,7 @@ def test_mol_sanitized(self): def test_molecule_properties(self): df = extract_molecules(SAMPLE_SDF) - row = df[df["chebi_id"] == "CHEBI:1"].iloc[0] + row = df[df["chebi_id"] == "1"].iloc[0] assert row["name"] == "compound A" assert row["smiles"] == "C" assert row["formula"] == "CH4" @@ -81,7 +81,7 @@ def test_gzipped_sdf(self, tmp_path): f_out.write(f_in.read()) df = extract_molecules(gz_path) assert len(df) == 2 - assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + assert set(df["chebi_id"]) == {"1", "2"} assert all(isinstance(m, rdchem.Mol) for m in df["mol"]) def test_empty_sdf_returns_empty_dataframe(self, tmp_path): @@ -90,11 +90,11 @@ def test_empty_sdf_returns_empty_dataframe(self, tmp_path): df = extract_molecules(empty_sdf) assert df.empty - def test_unparseable_molblock_gives_none(self, tmp_path, recwarn): + def test_unparseable_molblock_excluded(self, tmp_path, recwarn): bad_sdf = tmp_path / "bad.sdf" bad_sdf.write_text( "bad_mol\n\n 0 0 0 0 0 0 0 0 0 0999 V2000\nM END\n" "> \nCHEBI:99\n\n$$$$\n" ) df = extract_molecules(bad_sdf) - assert df.iloc[0]["mol"] is None + assert len(df) == 0 From 23812017b5cc52a4dda4a5f7d03d562772e1f5bc Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 27 Feb 2026 15:03:01 +0100 Subject: [PATCH 7/7] reformat ruff --- tests/test_obo_extractor.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py index c502ffe..6935260 100644 --- a/tests/test_obo_extractor.py +++ b/tests/test_obo_extractor.py @@ -29,8 +29,17 @@ def test_node_ids_are_strings(self): def test_expected_nodes_present(self): g = build_chebi_graph(SAMPLE_OBO) - assert set(g.nodes) == {"10", "133004", "22750", "24921", - "28017", '75771', '28057', '28102', '37163'} + assert set(g.nodes) == { + "10", + "133004", + "22750", + "24921", + "28017", + "75771", + "28057", + "28102", + "37163", + } def test_obsolete_term_excluded(self): g = build_chebi_graph(SAMPLE_OBO)