diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d99ed1c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,48 @@ +name: CI + +on: + push: + branches: ["**"] + pull_request: + branches: ["**"] + +jobs: + lint: + name: Lint (ruff) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install ruff + run: pip install ruff + + - name: Check formatting + run: ruff format --check . + + - name: Check linting + run: ruff check . + + test: + name: Unit Tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package and test dependencies + run: pip install -e ".[dev]" + + - name: Run tests + run: pytest tests/ -v diff --git a/README.md b/README.md index 6829bf6..ec443c0 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,85 @@ # python-chebi-utils -Common processing functionality for the ChEBI ontology (e.g. extraction of molecules, classes and relations). + +Common processing functionality for the ChEBI ontology — download data files, extract classes and relations, extract molecules, and generate stratified train/val/test splits. + +## Installation + +```bash +pip install chebi-utils +``` + +For development (includes `pytest` and `ruff`): + +```bash +pip install -e ".[dev]" +``` + +## Features + +### Download ChEBI data files + +```python +from chebi_utils import download_chebi_obo, download_chebi_sdf + +obo_path = download_chebi_obo(dest_dir="data/") # downloads chebi.obo +sdf_path = download_chebi_sdf(dest_dir="data/") # downloads chebi.sdf.gz +``` + +Files are fetched from the [EBI FTP server](https://ftp.ebi.ac.uk/pub/databases/chebi/). 
### Build an ontology graph of classes and relations

```python
from chebi_utils import build_chebi_graph

graph = build_chebi_graph("chebi.obo")  # networkx.DiGraph
# nodes: string CHEBI IDs with name, smiles and subset attributes
# edges: "relation" attribute — "is_a" (child → parent) or "has_part" (whole → part)
```
diff --git a/chebi_utils/__init__.py b/chebi_utils/__init__.py new file mode 100644 index 0000000..8367bab --- /dev/null +++ b/chebi_utils/__init__.py @@ -0,0 +1,12 @@ +from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf +from chebi_utils.obo_extractor import build_chebi_graph +from chebi_utils.sdf_extractor import extract_molecules +from chebi_utils.splitter import create_splits + +__all__ = [ + "download_chebi_obo", + "download_chebi_sdf", + "build_chebi_graph", + "extract_molecules", + "create_splits", +] diff --git a/chebi_utils/downloader.py b/chebi_utils/downloader.py new file mode 100644 index 0000000..2c320af --- /dev/null +++ b/chebi_utils/downloader.py @@ -0,0 +1,78 @@ +"""Download ChEBI data files from the EBI FTP server.""" + +from __future__ import annotations + +import urllib.request +from pathlib import Path + +_CHEBI_LEGACY_VERSION_THRESHOLD = 245 + + +def _chebi_obo_url(version: int) -> str: + if version < _CHEBI_LEGACY_VERSION_THRESHOLD: + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/ontology/chebi.obo" + + +def _chebi_sdf_url(version: int) -> str: + if version < _CHEBI_LEGACY_VERSION_THRESHOLD: + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/SDF/chebi.sdf.gz" + + +def download_chebi_obo( + version: int, + dest_dir: str | Path = ".", + filename: str = "chebi.obo", +) -> Path: + """Download a versioned ChEBI OBO ontology file from the EBI FTP server. + + Parameters + ---------- + version : int + ChEBI release version number (e.g. 230, 245, 250). + Versions below 245 are fetched from the legacy archive path. + dest_dir : str or Path + Directory where the file will be saved (created if it doesn't exist). 
+ filename : str + Name for the downloaded file. + + Returns + ------- + Path + Path to the downloaded file. + """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + dest_path = dest_dir / filename + urllib.request.urlretrieve(_chebi_obo_url(version), dest_path) + return dest_path + + +def download_chebi_sdf( + version: int, + dest_dir: str | Path = ".", + filename: str = "chebi.sdf.gz", +) -> Path: + """Download a versioned ChEBI SDF file from the EBI FTP server. + + Parameters + ---------- + version : int + ChEBI release version number (e.g. 230, 245, 250). + Versions below 245 are fetched from the legacy archive path. + dest_dir : str or Path + Directory where the file will be saved (created if it doesn't exist). + filename : str + Name for the downloaded file. + + Returns + ------- + Path + Path to the downloaded file. + """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + dest_path = dest_dir / filename + urllib.request.urlretrieve(_chebi_sdf_url(version), dest_path) + return dest_path diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py new file mode 100644 index 0000000..78224cb --- /dev/null +++ b/chebi_utils/obo_extractor.py @@ -0,0 +1,113 @@ +"""Extract ChEBI ontology data using fastobo and build a networkx graph.""" + +from __future__ import annotations + +from pathlib import Path + +import fastobo +import networkx as nx + + +def _chebi_id_to_str(chebi_id: str) -> str: + """Convert 'CHEBI:123' to '123' (string).""" + return chebi_id.split(":")[1] + + +def _term_data(doc: "fastobo.term.TermFrame") -> dict | None: + """Extract data from a single fastobo TermFrame. + + Returns + ------- + dict or None + Parsed term data, or ``None`` if the term is marked as obsolete. 
+ """ + parents: list[str] = [] + has_part: set[str] = set() + name: str | None = None + smiles: str | None = None + subset: str | None = None + + for clause in doc: + if isinstance(clause, fastobo.term.IsObsoleteClause): + if clause.obsolete: + return None + elif isinstance(clause, fastobo.term.PropertyValueClause): + pv = clause.property_value + if str(pv.relation) in ( + "chemrof:smiles_string", + "http://purl.obolibrary.org/obo/chebi/smiles", + ): + smiles = pv.value + elif isinstance(clause, fastobo.term.SynonymClause): + if "SMILES" in clause.raw_value() and smiles is None: + smiles = clause.raw_value().split('"')[1] + elif isinstance(clause, fastobo.term.RelationshipClause): + if str(clause.typedef) == "has_part": + has_part.add(_chebi_id_to_str(str(clause.term))) + elif isinstance(clause, fastobo.term.IsAClause): + parents.append(_chebi_id_to_str(str(clause.term))) + elif isinstance(clause, fastobo.term.NameClause): + name = str(clause.name) + elif isinstance(clause, fastobo.term.SubsetClause): + subset = str(clause.subset) + + return { + "id": _chebi_id_to_str(str(doc.id)), + "parents": parents, + "has_part": has_part, + "name": name, + "smiles": smiles, + "subset": subset, + } + + +def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: + """Parse a ChEBI OBO file and build a directed graph of ontology terms. + + ``xref:`` lines are stripped before parsing as they can cause fastobo + errors on some ChEBI releases. Only non-obsolete CHEBI-prefixed terms + are included. + + **Nodes** are string CHEBI IDs (e.g. ``"1"`` for ``CHEBI:1``) with + attributes ``name``, ``smiles``, and ``subset``. + + **Edges** carry a ``relation`` attribute and represent: + + - ``is_a`` — directed from child to parent + - ``has_part`` — directed from whole to part + + Parameters + ---------- + filepath : str or Path + Path to the ChEBI OBO file. + + Returns + ------- + nx.DiGraph + Directed graph of ChEBI ontology terms and their relationships. 
+ """ + with open(filepath, encoding="utf-8") as f: + content = "\n".join(line for line in f if not line.startswith("xref:")) + + graph: nx.DiGraph = nx.DiGraph() + + for frame in fastobo.loads(content): + if not ( + frame and isinstance(frame.id, fastobo.id.PrefixedIdent) and frame.id.prefix == "CHEBI" + ): + continue + + term = _term_data(frame) + if term is None: + continue + + node_id = term["id"] + graph.add_node(node_id, name=term["name"], smiles=term["smiles"], subset=term["subset"]) + + for parent_id in term["parents"]: + graph.add_edge(node_id, parent_id, relation="is_a") + + for part_id in term["has_part"]: + graph.add_edge(node_id, part_id, relation="has_part") + + return graph diff --git a/chebi_utils/sdf_extractor.py b/chebi_utils/sdf_extractor.py new file mode 100644 index 0000000..6c2e63b --- /dev/null +++ b/chebi_utils/sdf_extractor.py @@ -0,0 +1,162 @@ +"""Extract molecule data from ChEBI SDF files.""" + +from __future__ import annotations + +import gzip +import warnings +from pathlib import Path + +import pandas as pd +from rdkit import Chem + + +def _update_mol_valences(mol: Chem.Mol) -> Chem.Mol: + """Mark all atoms as having no implicit hydrogens to preserve molfile valences.""" + for atom in mol.GetAtoms(): + atom.SetNoImplicit(True) + return mol + + +def _parse_molblock(molblock: str, chebi_id: str | None = None) -> Chem.Mol | None: + """Parse a V2000/V3000 molblock into an RDKit Mol object. + + Uses partial sanitisation to handle ChEBI molecules with unusual valences + or radicals. + + Parameters + ---------- + molblock : str + The molblock string (header + atom/bond table + ``M END``). + chebi_id : str or None + Used only for the warning message when parsing fails. + + Returns + ------- + Chem.Mol or None + Parsed molecule, or ``None`` if parsing failed. 
+ """ + mol = Chem.MolFromMolBlock(molblock, sanitize=False, removeHs=False) + if mol is None: + warnings.warn(f"Failed to parse molblock for {chebi_id}", stacklevel=2) + return None + mol = _update_mol_valences(mol) + Chem.SanitizeMol( + mol, + sanitizeOps=( + Chem.SanitizeFlags.SANITIZE_FINDRADICALS + | Chem.SanitizeFlags.SANITIZE_KEKULIZE + | Chem.SanitizeFlags.SANITIZE_SETAROMATICITY + | Chem.SanitizeFlags.SANITIZE_SETCONJUGATION + | Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION + | Chem.SanitizeFlags.SANITIZE_SYMMRINGS + ), + catchErrors=True, + ) + return mol + + +def _iter_sdf_records(filepath: str | Path): + """Yield individual SDF records as strings.""" + opener = gzip.open if str(filepath).endswith(".gz") else open + current_record: list[str] = [] + + with opener(filepath, "rt", encoding="utf-8") as f: + for line in f: + current_record.append(line) + if line.strip() == "$$$$": + yield "".join(current_record) + current_record = [] + + +def _parse_sdf_record(record: str) -> tuple[dict[str, str], str]: + """Parse a single SDF record. + + Returns + ------- + tuple[dict[str, str], str] + ``(props, molblock)`` where *props* is a dict of data-item key/values + and *molblock* is the raw connection-table string. 
+ """ + props: dict[str, str] = {} + lines = record.splitlines(keepends=True) + + # Collect molblock: everything up to (but not including) the first "> <" tag + molblock_lines: list[str] = [] + data_start = len(lines) + for idx, line in enumerate(lines): + stripped = line.strip() + if stripped.startswith("> <") or stripped == "$$$$": + data_start = idx + break + molblock_lines.append(line) + molblock = "".join(molblock_lines) + + # Extract header name (first line of molblock) + if molblock_lines: + props["mol_name"] = molblock_lines[0].strip() + + # Parse data items + i = data_start + while i < len(lines): + line = lines[i].strip() + if line.startswith("> <") and line.endswith(">"): + key = line[3:-1] + value_lines: list[str] = [] + i += 1 + while i < len(lines) and lines[i].strip() not in ("", "$$$$"): + value_lines.append(lines[i].strip()) + i += 1 + props[key] = "\n".join(value_lines) + else: + i += 1 + + return props, molblock + + +def extract_molecules(filepath: str | Path) -> pd.DataFrame: + """Extract molecule data from a ChEBI SDF file. + + Supports both plain (``.sdf``) and gzip-compressed (``.sdf.gz``) files. + Each molecule is parsed into an RDKit ``Mol`` object stored in the ``mol`` + column. Molecules that cannot be parsed result in ``None`` in that column. + + Parameters + ---------- + filepath : str or Path + Path to the ChEBI SDF (or SDF.gz) file. + + Returns + ------- + pd.DataFrame + DataFrame with one row per molecule. Columns depend on the properties + present in the file. Common columns (renamed for convenience): + chebi_id, name, inchi, inchikey, smiles, formula, charge, mass, mol. 
+ """ + rows = [] + molblocks = [] + for record in _iter_sdf_records(filepath): + props, molblock = _parse_sdf_record(record) + rows.append(props) + molblocks.append(molblock) + + if not rows: + return pd.DataFrame() + + df = pd.DataFrame(rows) + + rename_map = { + "ChEBI ID": "chebi_id", + "ChEBI Name": "name", + "InChI": "inchi", + "InChIKey": "inchikey", + "SMILES": "smiles", + "Formulae": "formula", + "Charge": "charge", + "Mass": "mass", + } + df = df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns}) + + chebi_ids = df["chebi_id"].tolist() if "chebi_id" in df.columns else [None] * len(df) + df["mol"] = [_parse_molblock(mb, cid) for mb, cid in zip(molblocks, chebi_ids, strict=False)] + + return df diff --git a/chebi_utils/splitter.py b/chebi_utils/splitter.py new file mode 100644 index 0000000..dd86141 --- /dev/null +++ b/chebi_utils/splitter.py @@ -0,0 +1,105 @@ +"""Generate stratified train/validation/test splits from ChEBI DataFrames.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd + + +def create_splits( + df: pd.DataFrame, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + stratify_col: str | None = None, + seed: int = 42, +) -> dict[str, pd.DataFrame]: + """Create stratified train/validation/test splits of a DataFrame. + + Parameters + ---------- + df : pd.DataFrame + Input data to split. + train_ratio : float + Fraction of data for training (default 0.8). + val_ratio : float + Fraction of data for validation (default 0.1). + test_ratio : float + Fraction of data for testing (default 0.1). + stratify_col : str or None + Column name to use for stratification. If None, splits are random. + seed : int + Random seed for reproducibility. + + Returns + ------- + dict + Dictionary with keys ``'train'``, ``'val'``, ``'test'``, each + containing a DataFrame. + + Raises + ------ + ValueError + If the ratios do not sum to 1 or any ratio is outside ``[0, 1]``. 
+ """ + if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: + raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0") + if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]): + raise ValueError("All ratios must be between 0 and 1") + + rng = np.random.default_rng(seed) + + if stratify_col is not None: + return _stratified_split(df, train_ratio, val_ratio, test_ratio, stratify_col, rng) + return _random_split(df, train_ratio, val_ratio, test_ratio, rng) + + +def _random_split( + df: pd.DataFrame, + train_ratio: float, + val_ratio: float, + test_ratio: float, # noqa: ARG001 + rng: np.random.Generator, +) -> dict[str, pd.DataFrame]: + indices = rng.permutation(len(df)) + n_train = int(len(df) * train_ratio) + n_val = int(len(df) * val_ratio) + + train_idx = indices[:n_train] + val_idx = indices[n_train : n_train + n_val] + test_idx = indices[n_train + n_val :] + + return { + "train": df.iloc[train_idx].reset_index(drop=True), + "val": df.iloc[val_idx].reset_index(drop=True), + "test": df.iloc[test_idx].reset_index(drop=True), + } + + +def _stratified_split( + df: pd.DataFrame, + train_ratio: float, + val_ratio: float, + test_ratio: float, # noqa: ARG001 + stratify_col: str, + rng: np.random.Generator, +) -> dict[str, pd.DataFrame]: + train_indices: list[int] = [] + val_indices: list[int] = [] + test_indices: list[int] = [] + + for _, group in df.groupby(stratify_col, sort=False): + group_indices = rng.permutation(np.array(group.index.tolist())) + n = len(group_indices) + n_train = max(1, int(n * train_ratio)) + n_val = max(0, int(n * val_ratio)) + + train_indices.extend(group_indices[:n_train].tolist()) + val_indices.extend(group_indices[n_train : n_train + n_val].tolist()) + test_indices.extend(group_indices[n_train + n_val :].tolist()) + + return { + "train": df.loc[train_indices].reset_index(drop=True), + "val": df.loc[val_indices].reset_index(drop=True), + "test": df.loc[test_indices].reset_index(drop=True), + } diff --git 
a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..69e3be2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,34 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "chebi-utils" +version = "0.1.0" +description = "Common processing functionality for the ChEBI ontology" +readme = "README.md" +license = { file = "LICENSE" } +requires-python = ">=3.10" +dependencies = [ + "fastobo>=0.14", + "networkx>=3.0", + "numpy>=1.24", + "pandas>=2.0", + "rdkit>=2022.09", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "ruff>=0.4", +] + +[tool.ruff] +line-length = 99 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/fixtures/sample.obo b/tests/fixtures/sample.obo new file mode 100644 index 0000000..2b486fa --- /dev/null +++ b/tests/fixtures/sample.obo @@ -0,0 +1,37 @@ +format-version: 1.2 +ontology: chebi + +[Term] +id: CHEBI:1 +name: compound A +def: "A test compound." [TestDB:001] +is_a: CHEBI:2 ! compound B +relationship: has_part CHEBI:3 ! methyl group +property_value: http://purl.obolibrary.org/obo/chebi/smiles "C" xsd:string + +[Term] +id: CHEBI:2 +name: compound B +def: "Another test compound." [TestDB:002] +is_a: CHEBI:5 ! root compound + +[Term] +id: CHEBI:3 +name: methyl group +def: "A methyl group." [TestDB:003] +subset: 3_STAR + +[Term] +id: CHEBI:4 +name: obsolete term +def: "This term is obsolete." [TestDB:004] +is_obsolete: true + +[Term] +id: CHEBI:5 +name: root compound +def: "The root compound." 
[TestDB:005] + +[Typedef] +id: has_part +name: has part diff --git a/tests/fixtures/sample.sdf b/tests/fixtures/sample.sdf new file mode 100644 index 0000000..e877215 --- /dev/null +++ b/tests/fixtures/sample.sdf @@ -0,0 +1,64 @@ +compound_a + RDKit + + 1 0 0 0 0 0 0 0 0 0999 V2000 + 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 +M END +> +CHEBI:1 + +> +compound A + +> +C + +> +InChI=1S/CH4/h1H4 + +> +VNWKTOKETHGBQD-UHFFFAOYSA-N + +> +CH4 + +> +0 + +> +16.043 + +$$$$ +compound_b + RDKit + + 2 1 0 0 0 0 0 0 0 0999 V2000 + 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 + 1.5400 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 + 1 2 1 0 +M END +> +CHEBI:2 + +> +compound B + +> +CC + +> +InChI=1S/C2H6/c1-2/h1-2H3 + +> +OTMSDBZUPAUEDD-UHFFFAOYSA-N + +> +C2H6 + +> +0 + +> +30.069 + +$$$$ diff --git a/tests/test_downloader.py b/tests/test_downloader.py new file mode 100644 index 0000000..126e791 --- /dev/null +++ b/tests/test_downloader.py @@ -0,0 +1,103 @@ +"""Tests for chebi_utils.downloader.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from chebi_utils.downloader import ( + _chebi_obo_url, + _chebi_sdf_url, + download_chebi_obo, + download_chebi_sdf, +) + +# --- URL helper tests --- + + +def test_obo_url_modern_version(): + url = _chebi_obo_url(245) + assert url == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel245/ontology/chebi.obo" + + +def test_obo_url_legacy_version(): + url = _chebi_obo_url(244) + assert ( + url + == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel244/ontology/chebi.obo" + ) + + +def test_sdf_url_modern_version(): + url = _chebi_sdf_url(245) + assert url == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel245/SDF/chebi.sdf.gz" + + +def test_sdf_url_legacy_version(): + url = _chebi_sdf_url(230) + assert ( + url + == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel230/ontology/chebi.obo" + ) + + +# --- download_chebi_obo tests --- + 
+ +def test_download_chebi_obo_calls_urlretrieve_modern(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_obo(version=250, dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(_chebi_obo_url(250), tmp_path / "chebi.obo") + assert result == tmp_path / "chebi.obo" + + +def test_download_chebi_obo_calls_urlretrieve_legacy(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_obo(version=230, dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(_chebi_obo_url(230), tmp_path / "chebi.obo") + assert result == tmp_path / "chebi.obo" + + +def test_download_chebi_obo_custom_filename(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + result = download_chebi_obo(version=250, dest_dir=tmp_path, filename="my_chebi.obo") + assert result == tmp_path / "my_chebi.obo" + + +# --- download_chebi_sdf tests --- + + +def test_download_chebi_sdf_calls_urlretrieve_modern(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_sdf(version=250, dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(_chebi_sdf_url(250), tmp_path / "chebi.sdf.gz") + assert result == tmp_path / "chebi.sdf.gz" + + +def test_download_chebi_sdf_calls_urlretrieve_legacy(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_sdf(version=230, dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(_chebi_sdf_url(230), tmp_path / "chebi.sdf.gz") + assert result == tmp_path / "chebi.sdf.gz" + + +def test_download_chebi_sdf_custom_filename(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + result = download_chebi_sdf(version=250, dest_dir=tmp_path, filename="my_chebi.sdf.gz") + assert result == tmp_path / "my_chebi.sdf.gz" + + +# --- shared behaviour tests --- + + +def 
test_download_creates_dest_dir(tmp_path): + new_dir = tmp_path / "subdir" / "nested" + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + download_chebi_obo(version=250, dest_dir=new_dir) + assert new_dir.is_dir() + + +def test_download_returns_path_object(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + result = download_chebi_obo(version=250, dest_dir=str(tmp_path)) + assert isinstance(result, Path) diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py new file mode 100644 index 0000000..d5a50e6 --- /dev/null +++ b/tests/test_obo_extractor.py @@ -0,0 +1,86 @@ +"""Tests for chebi_utils.obo_extractor.""" + +from __future__ import annotations + +from pathlib import Path + +import networkx as nx + +from chebi_utils.obo_extractor import build_chebi_graph + +FIXTURES = Path(__file__).parent / "fixtures" +SAMPLE_OBO = FIXTURES / "sample.obo" + + +class TestBuildChebiGraph: + def test_returns_directed_graph(self): + g = build_chebi_graph(SAMPLE_OBO) + assert isinstance(g, nx.DiGraph) + + def test_correct_number_of_nodes(self): + # CHEBI:4 is obsolete and must be excluded -> 4 nodes remain + g = build_chebi_graph(SAMPLE_OBO) + assert len(g.nodes) == 4 + + def test_node_ids_are_strings(self): + g = build_chebi_graph(SAMPLE_OBO) + assert all(isinstance(n, str) for n in g.nodes) + + def test_expected_nodes_present(self): + g = build_chebi_graph(SAMPLE_OBO) + assert set(g.nodes) == {"1", "2", "3", "5"} + + def test_obsolete_term_excluded(self): + g = build_chebi_graph(SAMPLE_OBO) + assert "4" not in g.nodes + + def test_node_name_attribute(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes["1"]["name"] == "compound A" + assert g.nodes["2"]["name"] == "compound B" + + def test_smiles_extracted_from_property_value(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes["1"]["smiles"] == "C" + + def test_smiles_none_when_absent(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes["2"]["smiles"] 
is None + + def test_subset_extracted(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes["3"]["subset"] == "3_STAR" + + def test_subset_none_when_absent(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes["1"]["subset"] is None + + def test_isa_edge_present(self): + g = build_chebi_graph(SAMPLE_OBO) + # CHEBI:1 is_a CHEBI:2 + assert g.has_edge("1", "2") + assert g.edges["1", "2"]["relation"] == "is_a" + + def test_has_part_edge_present(self): + g = build_chebi_graph(SAMPLE_OBO) + # CHEBI:1 has_part CHEBI:3 + assert g.has_edge("1", "3") + assert g.edges["1", "3"]["relation"] == "has_part" + + def test_total_edge_count(self): + g = build_chebi_graph(SAMPLE_OBO) + # 1->2 (is_a), 1->3 (has_part), 2->5 (is_a) + assert len(g.edges) == 3 + + def test_typedef_stanza_excluded(self): + g = build_chebi_graph(SAMPLE_OBO) + # "has_part" Typedef id is not numeric CHEBI ID, should not appear as node + assert "has_part" not in g.nodes + + def test_xref_lines_do_not_break_parsing(self, tmp_path): + obo_with_xrefs = tmp_path / "xref.obo" + obo_with_xrefs.write_text( + "format-version: 1.2\n[Term]\nid: CHEBI:10\nname: test\nxref: Reaxys:123456\n" + ) + g = build_chebi_graph(obo_with_xrefs) + assert "10" in g.nodes diff --git a/tests/test_sdf_extractor.py b/tests/test_sdf_extractor.py new file mode 100644 index 0000000..b7be8a5 --- /dev/null +++ b/tests/test_sdf_extractor.py @@ -0,0 +1,100 @@ +"""Tests for chebi_utils.sdf_extractor.""" + +from __future__ import annotations + +import gzip +from pathlib import Path + +from rdkit.Chem import rdchem + +from chebi_utils.sdf_extractor import extract_molecules + +FIXTURES = Path(__file__).parent / "fixtures" +SAMPLE_SDF = FIXTURES / "sample.sdf" + + +class TestExtractMolecules: + def test_returns_two_molecules(self): + df = extract_molecules(SAMPLE_SDF) + assert len(df) == 2 + + def test_chebi_id_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "chebi_id" in df.columns + + def test_chebi_ids_correct(self): 
+ df = extract_molecules(SAMPLE_SDF) + assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + + def test_name_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "name" in df.columns + + def test_smiles_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "smiles" in df.columns + + def test_formula_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "formula" in df.columns + + def test_inchi_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "inchi" in df.columns + + def test_inchikey_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "inchikey" in df.columns + + def test_mol_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "mol" in df.columns + + def test_mol_objects_are_rdkit_mol(self): + df = extract_molecules(SAMPLE_SDF) + for mol in df["mol"]: + assert isinstance(mol, rdchem.Mol) + + def test_mol_atom_counts(self): + df = extract_molecules(SAMPLE_SDF) + row1 = df[df["chebi_id"] == "CHEBI:1"].iloc[0] + row2 = df[df["chebi_id"] == "CHEBI:2"].iloc[0] + assert row1["mol"].GetNumAtoms() == 1 # methane: 1 C + assert row2["mol"].GetNumAtoms() == 2 # ethane: 2 C + + def test_mol_sanitized(self): + df = extract_molecules(SAMPLE_SDF) + for mol in df["mol"]: + # Aromaticity flags should be set (sanitize applied) + assert mol is not None + + def test_molecule_properties(self): + df = extract_molecules(SAMPLE_SDF) + row = df[df["chebi_id"] == "CHEBI:1"].iloc[0] + assert row["name"] == "compound A" + assert row["smiles"] == "C" + assert row["formula"] == "CH4" + + def test_gzipped_sdf(self, tmp_path): + gz_path = tmp_path / "sample.sdf.gz" + with open(SAMPLE_SDF, "rb") as f_in, gzip.open(gz_path, "wb") as f_out: + f_out.write(f_in.read()) + df = extract_molecules(gz_path) + assert len(df) == 2 + assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + assert all(isinstance(m, rdchem.Mol) for m in df["mol"]) + + def test_empty_sdf_returns_empty_dataframe(self, tmp_path): + 
empty_sdf = tmp_path / "empty.sdf" + empty_sdf.write_text("") + df = extract_molecules(empty_sdf) + assert df.empty + + def test_unparseable_molblock_gives_none(self, tmp_path, recwarn): + bad_sdf = tmp_path / "bad.sdf" + bad_sdf.write_text( + "bad_mol\n\n 0 0 0 0 0 0 0 0 0 0999 V2000\nM END\n" + "> \nCHEBI:99\n\n$$$$\n" + ) + df = extract_molecules(bad_sdf) + assert df.iloc[0]["mol"] is None diff --git a/tests/test_splitter.py b/tests/test_splitter.py new file mode 100644 index 0000000..10b9434 --- /dev/null +++ b/tests/test_splitter.py @@ -0,0 +1,105 @@ +"""Tests for chebi_utils.splitter.""" + +from __future__ import annotations + +import pandas as pd +import pytest + +from chebi_utils.splitter import create_splits + + +@pytest.fixture +def sample_df(): + return pd.DataFrame( + { + "id": [f"CHEBI:{i}" for i in range(100)], + "category": (["A"] * 50) + (["B"] * 30) + (["C"] * 20), + } + ) + + +class TestCreateSplitsRandom: + def test_returns_three_splits(self, sample_df): + splits = create_splits(sample_df) + assert set(splits.keys()) == {"train", "val", "test"} + + def test_split_sizes_sum_to_total(self, sample_df): + splits = create_splits(sample_df) + total = sum(len(v) for v in splits.values()) + assert total == len(sample_df) + + def test_default_ratios(self, sample_df): + splits = create_splits(sample_df) + assert len(splits["train"]) == 80 + assert len(splits["val"]) == 10 + assert len(splits["test"]) == 10 + + def test_reproducible_with_same_seed(self, sample_df): + splits1 = create_splits(sample_df, seed=0) + splits2 = create_splits(sample_df, seed=0) + pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) + + def test_different_seeds_give_different_splits(self, sample_df): + splits1 = create_splits(sample_df, seed=0) + splits2 = create_splits(sample_df, seed=1) + assert not splits1["train"]["id"].equals(splits2["train"]["id"]) + + def test_no_overlap_between_splits(self, sample_df): + splits = create_splits(sample_df) + train_ids = 
set(splits["train"]["id"]) + val_ids = set(splits["val"]["id"]) + test_ids = set(splits["test"]["id"]) + assert train_ids.isdisjoint(val_ids) + assert train_ids.isdisjoint(test_ids) + assert val_ids.isdisjoint(test_ids) + + def test_all_rows_covered(self, sample_df): + splits = create_splits(sample_df) + all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"]) + assert all_ids == set(sample_df["id"]) + + def test_custom_ratios(self, sample_df): + splits = create_splits(sample_df, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1) + assert len(splits["train"]) == 70 + assert len(splits["val"]) == 20 + assert len(splits["test"]) == 10 + + def test_invalid_ratios_raise_error(self, sample_df): + with pytest.raises(ValueError, match="must equal 1.0"): + create_splits(sample_df, train_ratio=0.5, val_ratio=0.3, test_ratio=0.3) + + def test_negative_ratio_raises_error(self, sample_df): + with pytest.raises(ValueError, match="between 0 and 1"): + create_splits(sample_df, train_ratio=-0.1, val_ratio=0.6, test_ratio=0.5) + + +class TestCreateSplitsStratified: + def test_stratified_split_returns_three_splits(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + assert set(splits.keys()) == {"train", "val", "test"} + + def test_stratified_sizes_sum_to_total(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + total = sum(len(v) for v in splits.values()) + assert total == len(sample_df) + + def test_stratified_preserves_class_proportions(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + train = splits["train"] + counts = train["category"].value_counts(normalize=True) + # Category A: 50% of data -> ~50% of train + assert abs(counts.get("A", 0) - 0.5) < 0.1 + + def test_stratified_no_overlap(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + train_ids = set(splits["train"]["id"]) + val_ids = set(splits["val"]["id"]) + test_ids = 
set(splits["test"]["id"]) + assert train_ids.isdisjoint(val_ids) + assert train_ids.isdisjoint(test_ids) + assert val_ids.isdisjoint(test_ids) + + def test_stratified_reproducible(self, sample_df): + splits1 = create_splits(sample_df, stratify_col="category", seed=42) + splits2 = create_splits(sample_df, stratify_col="category", seed=42) + pd.testing.assert_frame_equal(splits1["train"], splits2["train"])