diff --git a/chebi_utils/__init__.py b/chebi_utils/__init__.py index 8da2062..1f3dc03 100644 --- a/chebi_utils/__init__.py +++ b/chebi_utils/__init__.py @@ -2,7 +2,7 @@ from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf from chebi_utils.obo_extractor import build_chebi_graph from chebi_utils.sdf_extractor import extract_molecules -from chebi_utils.splitter import create_splits +from chebi_utils.splitter import create_multilabel_splits __all__ = [ "build_labeled_dataset", @@ -10,5 +10,5 @@ "download_chebi_sdf", "build_chebi_graph", "extract_molecules", - "create_splits", + "create_multilabel_splits", ] diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py index dda4901..1fc427f 100644 --- a/chebi_utils/obo_extractor.py +++ b/chebi_utils/obo_extractor.py @@ -116,8 +116,9 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: return graph -def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.Graph: - """Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a). Also removes nodes that are not connected by any is_a edges to other nodes.""" +def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.DiGraph: + """Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a). + Also removes nodes that are not connected by any is_a edges to other nodes.""" return chebi_graph.edge_subgraph( (u, v) for u, v, d in chebi_graph.edges(data=True) if d.get("relation") == "is_a" ) diff --git a/chebi_utils/splitter.py b/chebi_utils/splitter.py index cd705aa..ead7129 100644 --- a/chebi_utils/splitter.py +++ b/chebi_utils/splitter.py @@ -2,33 +2,43 @@ from __future__ import annotations -import numpy as np import pandas as pd -def create_splits( +def create_multilabel_splits( df: pd.DataFrame, + label_start_col: int = 2, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, - stratify_col: str | None = None, - seed: int = 42, + seed: int | None = 42, ) -> dict[str, pd.DataFrame]: - """Create stratified train/validation/test splits of a DataFrame. + """Create stratified train/validation/test splits for multilabel DataFrames. + + Columns from index *label_start_col* onwards are treated as binary label + columns (one boolean column per label). The stratification strategy is + chosen automatically based on the number of label columns: + + - More than one label column: ``MultilabelStratifiedShuffleSplit`` from + the ``iterative-stratification`` package. + - Single label column: ``StratifiedShuffleSplit`` from ``scikit-learn``. Parameters ---------- df : pd.DataFrame - Input data to split. + Input data. Columns ``0`` to ``label_start_col - 1`` are treated as + feature/metadata columns; all remaining columns are boolean label + columns. A typical ChEBI DataFrame has columns + ``["chebi_id", "mol", "label1", "label2", ...]``. + label_start_col : int + Index of the first label column (default 2). train_ratio : float Fraction of data for training (default 0.8). val_ratio : float Fraction of data for validation (default 0.1). test_ratio : float Fraction of data for testing (default 0.1). - stratify_col : str or None - Column name to use for stratification. If None, splits are random. - seed : int + seed : int or None Random seed for reproducibility. Returns @@ -40,44 +50,60 @@ def create_splits( Raises ------ ValueError - If the ratios do not sum to 1 or any ratio is outside ``[0, 1]``. + If the ratios do not sum to 1, any ratio is outside ``[0, 1]``, or + *label_start_col* is out of range. """ if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0") if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]): raise ValueError("All ratios must be between 0 and 1") + if label_start_col >= len(df.columns): + raise ValueError( + f"label_start_col={label_start_col} is out of range for a DataFrame " + f"with {len(df.columns)} columns" + ) - rng = np.random.default_rng(seed) + from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit + from sklearn.model_selection import StratifiedShuffleSplit - if stratify_col is not None: - return _stratified_split(df, train_ratio, val_ratio, test_ratio, stratify_col, rng) - return _random_split(df, train_ratio, val_ratio, test_ratio, rng) + labels_matrix = df.iloc[:, label_start_col:].values + is_multilabel = labels_matrix.shape[1] > 1 + # StratifiedShuffleSplit requires a 1-D label array + y = labels_matrix if is_multilabel else labels_matrix[:, 0] + df_reset = df.reset_index(drop=True) -def _stratified_split( - df: pd.DataFrame, - train_ratio: float, - val_ratio: float, - test_ratio: float, # noqa: ARG001 - stratify_col: str, - rng: np.random.Generator, -) -> dict[str, pd.DataFrame]: - train_indices: list[int] = [] - val_indices: list[int] = [] - test_indices: list[int] = [] + # ── Step 1: carve out the test set ────────────────────────────────────── + if is_multilabel: + test_splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_ratio, random_state=seed + ) + else: + test_splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=seed) + train_val_idx, test_idx = next(test_splitter.split(y, y)) + + df_test = df_reset.iloc[test_idx] + df_trainval = df_reset.iloc[train_val_idx] + + # ── Step 2: split train/val from the remaining data ───────────────────── + y_trainval = y[train_val_idx] + val_ratio_adjusted = val_ratio / (1.0 - test_ratio) - for _, group in df.groupby(stratify_col, sort=False): - group_indices = rng.permutation(np.array(group.index.tolist())) - n = len(group_indices) - n_train = max(1, int(n * train_ratio)) - n_val = max(0, int(n * val_ratio)) + if is_multilabel: + val_splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=val_ratio_adjusted, random_state=seed + ) + else: + val_splitter = StratifiedShuffleSplit( + n_splits=1, test_size=val_ratio_adjusted, random_state=seed + ) + train_idx_inner, val_idx_inner = next(val_splitter.split(y_trainval, y_trainval)) - train_indices.extend(group_indices[:n_train].tolist()) - val_indices.extend(group_indices[n_train : n_train + n_val].tolist()) - test_indices.extend(group_indices[n_train + n_val :].tolist()) + df_train = df_trainval.iloc[train_idx_inner] + df_val = df_trainval.iloc[val_idx_inner] return { - "train": df.loc[train_indices].reset_index(drop=True), - "val": df.loc[val_indices].reset_index(drop=True), - "test": df.loc[test_indices].reset_index(drop=True), + "train": df_train.reset_index(drop=True), + "val": df_val.reset_index(drop=True), + "test": df_test.reset_index(drop=True), } diff --git a/pyproject.toml b/pyproject.toml index 86b4fe7..4be79f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,10 +11,12 @@ license = { file = "LICENSE" } requires-python = ">=3.10" dependencies = [ "fastobo>=0.14", + "iterative-stratification>=0.1.9", "networkx>=3.0", "numpy>=1.24", "pandas>=2.0", "rdkit>=2022.09", + "scikit-learn>=1.0", "chembl_structure_pipeline>=1.2.4", ] diff --git a/tests/fixtures/sample.obo b/tests/fixtures/sample.obo index dd586c3..afda749 100644 --- a/tests/fixtures/sample.obo +++ b/tests/fixtures/sample.obo @@ -84,4 +84,4 @@ xref: wikipedia.en:Starch {source="wikipedia.en"} is_a: CHEBI:37163 ! glucan relationship: BFO:0000051 CHEBI:28057 ! has part amylopectin relationship: BFO:0000051 CHEBI:28102 ! has part amylose -relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite \ No newline at end of file +relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py index 43d8647..6935260 100644 --- a/tests/test_obo_extractor.py +++ b/tests/test_obo_extractor.py @@ -18,9 +18,10 @@ def test_returns_directed_graph(self): assert isinstance(g, nx.DiGraph) def test_correct_number_of_nodes(self): - # CHEBI:27189 is obsolete -> excluded; 3 explicit + 1 implicit (24921) = 4 + # CHEBI:27189 is obsolete -> excluded; + # 4 explicit + 5 implicit (superclasses and relation targets) = 9 g = build_chebi_graph(SAMPLE_OBO) - assert len(g.nodes) == 4 + assert len(g.nodes) == 9 def test_node_ids_are_strings(self): g = build_chebi_graph(SAMPLE_OBO) @@ -28,7 +29,17 @@ def test_node_ids_are_strings(self): def test_expected_nodes_present(self): g = build_chebi_graph(SAMPLE_OBO) - assert set(g.nodes) == {"10", "133004", "22750", "24921"} + assert set(g.nodes) == { + "10", + "133004", + "22750", + "24921", + "28017", + "75771", + "28057", + "28102", + "37163", + } def test_obsolete_term_excluded(self): g = build_chebi_graph(SAMPLE_OBO) @@ -71,8 +82,8 @@ def test_isa_chain(self): def test_total_edge_count(self): g = build_chebi_graph(SAMPLE_OBO) - # 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a) - assert len(g.edges) == 3 + # 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a), ... + assert len(g.edges) == 7 def test_xref_lines_do_not_break_parsing(self, tmp_path): obo_with_xrefs = tmp_path / "xref.obo" diff --git a/tests/test_sdf_extractor.py b/tests/test_sdf_extractor.py index b7be8a5..2aaf223 100644 --- a/tests/test_sdf_extractor.py +++ b/tests/test_sdf_extractor.py @@ -24,7 +24,7 @@ def test_chebi_id_column_present(self): def test_chebi_ids_correct(self): df = extract_molecules(SAMPLE_SDF) - assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + assert set(df["chebi_id"]) == {"1", "2"} def test_name_column_present(self): df = extract_molecules(SAMPLE_SDF) @@ -57,8 +57,8 @@ def test_mol_objects_are_rdkit_mol(self): def test_mol_atom_counts(self): df = extract_molecules(SAMPLE_SDF) - row1 = df[df["chebi_id"] == "CHEBI:1"].iloc[0] - row2 = df[df["chebi_id"] == "CHEBI:2"].iloc[0] + row1 = df[df["chebi_id"] == "1"].iloc[0] + row2 = df[df["chebi_id"] == "2"].iloc[0] assert row1["mol"].GetNumAtoms() == 1 # methane: 1 C assert row2["mol"].GetNumAtoms() == 2 # ethane: 2 C @@ -70,7 +70,7 @@ def test_mol_sanitized(self): def test_molecule_properties(self): df = extract_molecules(SAMPLE_SDF) - row = df[df["chebi_id"] == "CHEBI:1"].iloc[0] + row = df[df["chebi_id"] == "1"].iloc[0] assert row["name"] == "compound A" assert row["smiles"] == "C" assert row["formula"] == "CH4" @@ -81,7 +81,7 @@ def test_gzipped_sdf(self, tmp_path): f_out.write(f_in.read()) df = extract_molecules(gz_path) assert len(df) == 2 - assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + assert set(df["chebi_id"]) == {"1", "2"} assert all(isinstance(m, rdchem.Mol) for m in df["mol"]) def test_empty_sdf_returns_empty_dataframe(self, tmp_path): @@ -90,11 +90,11 @@ def test_empty_sdf_returns_empty_dataframe(self, tmp_path): df = extract_molecules(empty_sdf) assert df.empty - def test_unparseable_molblock_gives_none(self, tmp_path, recwarn): + def test_unparseable_molblock_excluded(self, tmp_path, recwarn): bad_sdf = tmp_path / "bad.sdf" bad_sdf.write_text( "bad_mol\n\n 0 0 0 0 0 0 0 0 0 0999 V2000\nM END\n" "> \nCHEBI:99\n\n$$$$\n" ) df = extract_molecules(bad_sdf) - assert df.iloc[0]["mol"] is None + assert len(df) == 0 diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 10b9434..0f5e470 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -5,101 +5,119 @@ import pandas as pd import pytest -from chebi_utils.splitter import create_splits +from chebi_utils.splitter import create_multilabel_splits @pytest.fixture -def sample_df(): +def multilabel_df(): + """DataFrame with three bool label columns starting at index 2 (200 rows). + + Column layout: chebi_id | mol | label_A | label_B | label_C + Each row gets one of six label combinations in a round-robin pattern. + """ + patterns = [ + (True, False, False), + (False, True, False), + (False, False, True), + (True, True, False), + (True, False, True), + (False, True, True), + ] + rows = [patterns[i % len(patterns)] for i in range(200)] + label_a, label_b, label_c = zip(*rows) return pd.DataFrame( { - "id": [f"CHEBI:{i}" for i in range(100)], - "category": (["A"] * 50) + (["B"] * 30) + (["C"] * 20), + "chebi_id": [f"CHEBI:{i}" for i in range(200)], + "mol": ["mol"] * 200, + "label_A": list(label_a), + "label_B": list(label_b), + "label_C": list(label_c), } ) -class TestCreateSplitsRandom: - def test_returns_three_splits(self, sample_df): - splits = create_splits(sample_df) - assert set(splits.keys()) == {"train", "val", "test"} - - def test_split_sizes_sum_to_total(self, sample_df): - splits = create_splits(sample_df) - total = sum(len(v) for v in splits.values()) - assert total == len(sample_df) +@pytest.fixture +def singlelabel_df(): + """DataFrame with a single bool label column at index 2 (200 rows).""" + return pd.DataFrame( + { + "chebi_id": [f"CHEBI:{i}" for i in range(200)], + "mol": ["mol"] * 200, + "label_A": [i % 2 == 0 for i in range(200)], + } + ) - def test_default_ratios(self, sample_df): - splits = create_splits(sample_df) - assert len(splits["train"]) == 80 - assert len(splits["val"]) == 10 - assert len(splits["test"]) == 10 - def test_reproducible_with_same_seed(self, sample_df): - splits1 = create_splits(sample_df, seed=0) - splits2 = create_splits(sample_df, seed=0) - pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) +class TestCreateMultilabelSplits: + def test_returns_three_splits(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df) + assert set(splits.keys()) == {"train", "val", "test"} - def test_different_seeds_give_different_splits(self, sample_df): - splits1 = create_splits(sample_df, seed=0) - splits2 = create_splits(sample_df, seed=1) - assert not splits1["train"]["id"].equals(splits2["train"]["id"]) + def test_sizes_sum_to_total(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df) + assert sum(len(v) for v in splits.values()) == len(multilabel_df) - def test_no_overlap_between_splits(self, sample_df): - splits = create_splits(sample_df) - train_ids = set(splits["train"]["id"]) - val_ids = set(splits["val"]["id"]) - test_ids = set(splits["test"]["id"]) + def test_no_overlap(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df) + train_ids = set(splits["train"]["chebi_id"]) + val_ids = set(splits["val"]["chebi_id"]) + test_ids = set(splits["test"]["chebi_id"]) assert train_ids.isdisjoint(val_ids) assert train_ids.isdisjoint(test_ids) assert val_ids.isdisjoint(test_ids) - def test_all_rows_covered(self, sample_df): - splits = create_splits(sample_df) - all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"]) - assert all_ids == set(sample_df["id"]) - - def test_custom_ratios(self, sample_df): - splits = create_splits(sample_df, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1) - assert len(splits["train"]) == 70 - assert len(splits["val"]) == 20 - assert len(splits["test"]) == 10 + def test_all_rows_covered(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df) + all_ids = ( + set(splits["train"]["chebi_id"]) + | set(splits["val"]["chebi_id"]) + | set(splits["test"]["chebi_id"]) + ) + assert all_ids == set(multilabel_df["chebi_id"]) + + def test_label_columns_preserved(self, multilabel_df): + splits = create_multilabel_splits(multilabel_df) + for split in splits.values(): + assert list(split.columns) == list(multilabel_df.columns) + + def test_reproducible_with_same_seed(self, multilabel_df): + splits1 = create_multilabel_splits(multilabel_df, seed=7) + splits2 = create_multilabel_splits(multilabel_df, seed=7) + pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) - def test_invalid_ratios_raise_error(self, sample_df): + def test_different_seeds_give_different_splits(self, multilabel_df): + splits1 = create_multilabel_splits(multilabel_df, seed=1) + splits2 = create_multilabel_splits(multilabel_df, seed=2) + assert not splits1["train"]["chebi_id"].equals(splits2["train"]["chebi_id"]) + + def test_approximate_split_sizes(self, multilabel_df): + splits = create_multilabel_splits( + multilabel_df, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1 + ) + n = len(multilabel_df) + assert abs(len(splits["test"]) - int(n * 0.1)) <= 2 + assert abs(len(splits["val"]) - int(n * 0.1)) <= 2 + + def test_custom_label_start_col(self, multilabel_df): + # Drop the 'mol' column so labels start at index 1 + df_no_mol = multilabel_df.drop(columns=["mol"]) + splits = create_multilabel_splits(df_no_mol, label_start_col=1) + assert sum(len(v) for v in splits.values()) == len(df_no_mol) + + def test_invalid_ratios_raise_error(self, multilabel_df): with pytest.raises(ValueError, match="must equal 1.0"): - create_splits(sample_df, train_ratio=0.5, val_ratio=0.3, test_ratio=0.3) - - def test_negative_ratio_raises_error(self, sample_df): - with pytest.raises(ValueError, match="between 0 and 1"): - create_splits(sample_df, train_ratio=-0.1, val_ratio=0.6, test_ratio=0.5) - - -class TestCreateSplitsStratified: - def test_stratified_split_returns_three_splits(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - assert set(splits.keys()) == {"train", "val", "test"} - - def test_stratified_sizes_sum_to_total(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - total = sum(len(v) for v in splits.values()) - assert total == len(sample_df) - - def test_stratified_preserves_class_proportions(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - train = splits["train"] - counts = train["category"].value_counts(normalize=True) - # Category A: 50% of data -> ~50% of train - assert abs(counts.get("A", 0) - 0.5) < 0.1 - - def test_stratified_no_overlap(self, sample_df): - splits = create_splits(sample_df, stratify_col="category") - train_ids = set(splits["train"]["id"]) - val_ids = set(splits["val"]["id"]) - test_ids = set(splits["test"]["id"]) + create_multilabel_splits(multilabel_df, train_ratio=0.5, val_ratio=0.3, test_ratio=0.3) + + def test_out_of_range_label_start_col_raises_error(self, multilabel_df): + with pytest.raises(ValueError, match="out of range"): + create_multilabel_splits(multilabel_df, label_start_col=100) + + def test_singlelabel_path(self, singlelabel_df): + """Single label column should use StratifiedShuffleSplit without error.""" + splits = create_multilabel_splits(singlelabel_df) + assert sum(len(v) for v in splits.values()) == len(singlelabel_df) + train_ids = set(splits["train"]["chebi_id"]) + val_ids = set(splits["val"]["chebi_id"]) + test_ids = set(splits["test"]["chebi_id"]) assert train_ids.isdisjoint(val_ids) assert train_ids.isdisjoint(test_ids) - assert val_ids.isdisjoint(test_ids) - - def test_stratified_reproducible(self, sample_df): - splits1 = create_splits(sample_df, stratify_col="category", seed=42) - splits2 = create_splits(sample_df, stratify_col="category", seed=42) - pd.testing.assert_frame_equal(splits1["train"], splits2["train"])