Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions chebi_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf
from chebi_utils.obo_extractor import build_chebi_graph
from chebi_utils.sdf_extractor import extract_molecules
from chebi_utils.splitter import create_splits
from chebi_utils.splitter import create_multilabel_splits

__all__ = [
"build_labeled_dataset",
"download_chebi_obo",
"download_chebi_sdf",
"build_chebi_graph",
"extract_molecules",
"create_splits",
"create_multilabel_splits",
]
5 changes: 3 additions & 2 deletions chebi_utils/obo_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph:
return graph


def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.Graph:
"""Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a). Also removes nodes that are not connected by any is_a edges to other nodes."""
def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.DiGraph:
"""Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a).
Also removes nodes that are not connected by any is_a edges to other nodes."""
return chebi_graph.edge_subgraph(
(u, v) for u, v, d in chebi_graph.edges(data=True) if d.get("relation") == "is_a"
)
98 changes: 62 additions & 36 deletions chebi_utils/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,43 @@

from __future__ import annotations

import numpy as np
import pandas as pd


def create_splits(
def create_multilabel_splits(
df: pd.DataFrame,
label_start_col: int = 2,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
test_ratio: float = 0.1,
stratify_col: str | None = None,
seed: int = 42,
seed: int | None = 42,
) -> dict[str, pd.DataFrame]:
"""Create stratified train/validation/test splits of a DataFrame.
"""Create stratified train/validation/test splits for multilabel DataFrames.

Columns from index *label_start_col* onwards are treated as binary label
columns (one boolean column per label). The stratification strategy is
chosen automatically based on the number of label columns:

- More than one label column: ``MultilabelStratifiedShuffleSplit`` from
the ``iterative-stratification`` package.
- Single label column: ``StratifiedShuffleSplit`` from ``scikit-learn``.

Parameters
----------
df : pd.DataFrame
Input data to split.
Input data. Columns ``0`` to ``label_start_col - 1`` are treated as
feature/metadata columns; all remaining columns are boolean label
columns. A typical ChEBI DataFrame has columns
``["chebi_id", "mol", "label1", "label2", ...]``.
label_start_col : int
Index of the first label column (default 2).
train_ratio : float
Fraction of data for training (default 0.8).
val_ratio : float
Fraction of data for validation (default 0.1).
test_ratio : float
Fraction of data for testing (default 0.1).
stratify_col : str or None
Column name to use for stratification. If None, splits are random.
seed : int
seed : int or None
Random seed for reproducibility.

Returns
Expand All @@ -40,44 +50,60 @@ def create_splits(
Raises
------
ValueError
If the ratios do not sum to 1 or any ratio is outside ``[0, 1]``.
If the ratios do not sum to 1, any ratio is outside ``[0, 1]``, or
*label_start_col* is out of range.
"""
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0")
if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]):
raise ValueError("All ratios must be between 0 and 1")
if label_start_col >= len(df.columns):
raise ValueError(
f"label_start_col={label_start_col} is out of range for a DataFrame "
f"with {len(df.columns)} columns"
)

rng = np.random.default_rng(seed)
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.model_selection import StratifiedShuffleSplit

if stratify_col is not None:
return _stratified_split(df, train_ratio, val_ratio, test_ratio, stratify_col, rng)
return _random_split(df, train_ratio, val_ratio, test_ratio, rng)
labels_matrix = df.iloc[:, label_start_col:].values
is_multilabel = labels_matrix.shape[1] > 1
# StratifiedShuffleSplit requires a 1-D label array
y = labels_matrix if is_multilabel else labels_matrix[:, 0]

df_reset = df.reset_index(drop=True)

def _stratified_split(
df: pd.DataFrame,
train_ratio: float,
val_ratio: float,
test_ratio: float, # noqa: ARG001
stratify_col: str,
rng: np.random.Generator,
) -> dict[str, pd.DataFrame]:
train_indices: list[int] = []
val_indices: list[int] = []
test_indices: list[int] = []
# ── Step 1: carve out the test set ──────────────────────────────────────
if is_multilabel:
test_splitter = MultilabelStratifiedShuffleSplit(
n_splits=1, test_size=test_ratio, random_state=seed
)
else:
test_splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=seed)
train_val_idx, test_idx = next(test_splitter.split(y, y))

df_test = df_reset.iloc[test_idx]
df_trainval = df_reset.iloc[train_val_idx]

# ── Step 2: split train/val from the remaining data ─────────────────────
y_trainval = y[train_val_idx]
val_ratio_adjusted = val_ratio / (1.0 - test_ratio)

for _, group in df.groupby(stratify_col, sort=False):
group_indices = rng.permutation(np.array(group.index.tolist()))
n = len(group_indices)
n_train = max(1, int(n * train_ratio))
n_val = max(0, int(n * val_ratio))
if is_multilabel:
val_splitter = MultilabelStratifiedShuffleSplit(
n_splits=1, test_size=val_ratio_adjusted, random_state=seed
)
else:
val_splitter = StratifiedShuffleSplit(
n_splits=1, test_size=val_ratio_adjusted, random_state=seed
)
train_idx_inner, val_idx_inner = next(val_splitter.split(y_trainval, y_trainval))

train_indices.extend(group_indices[:n_train].tolist())
val_indices.extend(group_indices[n_train : n_train + n_val].tolist())
test_indices.extend(group_indices[n_train + n_val :].tolist())
df_train = df_trainval.iloc[train_idx_inner]
df_val = df_trainval.iloc[val_idx_inner]

return {
"train": df.loc[train_indices].reset_index(drop=True),
"val": df.loc[val_indices].reset_index(drop=True),
"test": df.loc[test_indices].reset_index(drop=True),
"train": df_train.reset_index(drop=True),
"val": df_val.reset_index(drop=True),
"test": df_test.reset_index(drop=True),
}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ license = { file = "LICENSE" }
requires-python = ">=3.10"
dependencies = [
"fastobo>=0.14",
"iterative-stratification>=0.1.9",
"networkx>=3.0",
"numpy>=1.24",
"pandas>=2.0",
"rdkit>=2022.09",
"scikit-learn>=1.0",
"chembl_structure_pipeline>=1.2.4",
]

Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/sample.obo
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,4 @@ xref: wikipedia.en:Starch {source="wikipedia.en"}
is_a: CHEBI:37163 ! glucan
relationship: BFO:0000051 CHEBI:28057 ! has part amylopectin
relationship: BFO:0000051 CHEBI:28102 ! has part amylose
relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite
relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite
21 changes: 16 additions & 5 deletions tests/test_obo_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,28 @@ def test_returns_directed_graph(self):
assert isinstance(g, nx.DiGraph)

def test_correct_number_of_nodes(self):
# CHEBI:27189 is obsolete -> excluded; 3 explicit + 1 implicit (24921) = 4
# CHEBI:27189 is obsolete -> excluded;
# 4 explicit + 5 implicit (superclasses and relation targets) = 9
g = build_chebi_graph(SAMPLE_OBO)
assert len(g.nodes) == 4
assert len(g.nodes) == 9

def test_node_ids_are_strings(self):
g = build_chebi_graph(SAMPLE_OBO)
assert all(isinstance(n, str) for n in g.nodes)

def test_expected_nodes_present(self):
g = build_chebi_graph(SAMPLE_OBO)
assert set(g.nodes) == {"10", "133004", "22750", "24921"}
assert set(g.nodes) == {
"10",
"133004",
"22750",
"24921",
"28017",
"75771",
"28057",
"28102",
"37163",
}

def test_obsolete_term_excluded(self):
g = build_chebi_graph(SAMPLE_OBO)
Expand Down Expand Up @@ -71,8 +82,8 @@ def test_isa_chain(self):

def test_total_edge_count(self):
g = build_chebi_graph(SAMPLE_OBO)
# 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a)
assert len(g.edges) == 3
# 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a), ...
assert len(g.edges) == 7

def test_xref_lines_do_not_break_parsing(self, tmp_path):
obo_with_xrefs = tmp_path / "xref.obo"
Expand Down
14 changes: 7 additions & 7 deletions tests/test_sdf_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_chebi_id_column_present(self):

def test_chebi_ids_correct(self):
df = extract_molecules(SAMPLE_SDF)
assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"}
assert set(df["chebi_id"]) == {"1", "2"}

def test_name_column_present(self):
df = extract_molecules(SAMPLE_SDF)
Expand Down Expand Up @@ -57,8 +57,8 @@ def test_mol_objects_are_rdkit_mol(self):

def test_mol_atom_counts(self):
df = extract_molecules(SAMPLE_SDF)
row1 = df[df["chebi_id"] == "CHEBI:1"].iloc[0]
row2 = df[df["chebi_id"] == "CHEBI:2"].iloc[0]
row1 = df[df["chebi_id"] == "1"].iloc[0]
row2 = df[df["chebi_id"] == "2"].iloc[0]
assert row1["mol"].GetNumAtoms() == 1 # methane: 1 C
assert row2["mol"].GetNumAtoms() == 2 # ethane: 2 C

Expand All @@ -70,7 +70,7 @@ def test_mol_sanitized(self):

def test_molecule_properties(self):
df = extract_molecules(SAMPLE_SDF)
row = df[df["chebi_id"] == "CHEBI:1"].iloc[0]
row = df[df["chebi_id"] == "1"].iloc[0]
assert row["name"] == "compound A"
assert row["smiles"] == "C"
assert row["formula"] == "CH4"
Expand All @@ -81,7 +81,7 @@ def test_gzipped_sdf(self, tmp_path):
f_out.write(f_in.read())
df = extract_molecules(gz_path)
assert len(df) == 2
assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"}
assert set(df["chebi_id"]) == {"1", "2"}
assert all(isinstance(m, rdchem.Mol) for m in df["mol"])

def test_empty_sdf_returns_empty_dataframe(self, tmp_path):
Expand All @@ -90,11 +90,11 @@ def test_empty_sdf_returns_empty_dataframe(self, tmp_path):
df = extract_molecules(empty_sdf)
assert df.empty

def test_unparseable_molblock_gives_none(self, tmp_path, recwarn):
def test_unparseable_molblock_excluded(self, tmp_path, recwarn):
bad_sdf = tmp_path / "bad.sdf"
bad_sdf.write_text(
"bad_mol\n\n 0 0 0 0 0 0 0 0 0 0999 V2000\nM END\n"
"> <ChEBI ID>\nCHEBI:99\n\n$$$$\n"
)
df = extract_molecules(bad_sdf)
assert df.iloc[0]["mol"] is None
assert len(df) == 0
Loading