From 84ebd3f2adcb96513f198f048069598178ab9bb5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 08:46:48 +0000 Subject: [PATCH 1/6] Initial plan From 04aa4d36e1b150f69b8c72d29a23917c4f2e5f7f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 08:52:22 +0000 Subject: [PATCH 2/6] Add chebi_utils library with downloader, extractors, splitter, tests, and CI workflow Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- .github/workflows/ci.yml | 48 ++++++++++++++++ README.md | 85 +++++++++++++++++++++++++++- chebi_utils/__init__.py | 13 +++++ chebi_utils/downloader.py | 53 ++++++++++++++++++ chebi_utils/obo_extractor.py | 103 ++++++++++++++++++++++++++++++++++ chebi_utils/sdf_extractor.py | 85 ++++++++++++++++++++++++++++ chebi_utils/splitter.py | 105 +++++++++++++++++++++++++++++++++++ pyproject.toml | 31 +++++++++++ tests/__init__.py | 0 tests/fixtures/sample.obo | 30 ++++++++++ tests/fixtures/sample.sdf | 58 +++++++++++++++++++ tests/test_downloader.py | 52 +++++++++++++++++ tests/test_obo_extractor.py | 72 ++++++++++++++++++++++++ tests/test_sdf_extractor.py | 66 ++++++++++++++++++++++ tests/test_splitter.py | 105 +++++++++++++++++++++++++++++++++++ 15 files changed, 905 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/ci.yml create mode 100644 chebi_utils/__init__.py create mode 100644 chebi_utils/downloader.py create mode 100644 chebi_utils/obo_extractor.py create mode 100644 chebi_utils/sdf_extractor.py create mode 100644 chebi_utils/splitter.py create mode 100644 pyproject.toml create mode 100644 tests/__init__.py create mode 100644 tests/fixtures/sample.obo create mode 100644 tests/fixtures/sample.sdf create mode 100644 tests/test_downloader.py create mode 100644 tests/test_obo_extractor.py create mode 100644 tests/test_sdf_extractor.py create mode 100644 
tests/test_splitter.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d99ed1c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,48 @@ +name: CI + +on: + push: + branches: ["**"] + pull_request: + branches: ["**"] + +jobs: + lint: + name: Lint (ruff) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install ruff + run: pip install ruff + + - name: Check formatting + run: ruff format --check . + + - name: Check linting + run: ruff check . + + test: + name: Unit Tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package and test dependencies + run: pip install -e ".[dev]" + + - name: Run tests + run: pytest tests/ -v diff --git a/README.md b/README.md index 6829bf6..ec443c0 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,85 @@ # python-chebi-utils -Common processing functionality for the ChEBI ontology (e.g. extraction of molecules, classes and relations). + +Common processing functionality for the ChEBI ontology — download data files, extract classes and relations, extract molecules, and generate stratified train/val/test splits. + +## Installation + +```bash +pip install chebi-utils +``` + +For development (includes `pytest` and `ruff`): + +```bash +pip install -e ".[dev]" +``` + +## Features + +### Download ChEBI data files + +```python +from chebi_utils import download_chebi_obo, download_chebi_sdf + +obo_path = download_chebi_obo(dest_dir="data/") # downloads chebi.obo +sdf_path = download_chebi_sdf(dest_dir="data/") # downloads chebi.sdf.gz +``` + +Files are fetched from the [EBI FTP server](https://ftp.ebi.ac.uk/pub/databases/chebi/). 
+ +### Extract ontology classes and relations + +```python +from chebi_utils import extract_classes, extract_relations + +classes = extract_classes("chebi.obo") +# DataFrame: id, name, definition, is_obsolete + +relations = extract_relations("chebi.obo") +# DataFrame: source_id, target_id, relation_type (is_a, has_role, …) +``` + +### Extract molecules + +```python +from chebi_utils import extract_molecules + +molecules = extract_molecules("chebi.sdf.gz") +# DataFrame: chebi_id, name, smiles, inchi, inchikey, formula, charge, mass, … +``` + +Both plain `.sdf` and gzip-compressed `.sdf.gz` files are supported. + +### Generate train/val/test splits + +```python +from chebi_utils import create_splits + +splits = create_splits(molecules, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1) +train_df = splits["train"] +val_df = splits["val"] +test_df = splits["test"] +``` + +Pass `stratify_col` to preserve class proportions across splits: + +```python +splits = create_splits(classes, stratify_col="is_obsolete", seed=42) +``` + +## Running Tests + +```bash +pytest tests/ -v +``` + +## Linting + +```bash +ruff check . +ruff format --check . +``` + +## CI/CD + +A GitHub Actions workflow (`.github/workflows/ci.yml`) automatically runs ruff linting and the full test suite on every push and pull request across Python 3.10, 3.11, and 3.12. 
def download_chebi_obo(dest_dir: str | Path = ".", filename: str = "chebi.obo") -> Path:
    """Fetch the current ChEBI OBO ontology release from the EBI FTP server.

    Parameters
    ----------
    dest_dir : str or Path
        Target directory; created (including parents) when missing.
    filename : str
        Basename to store the download under.

    Returns
    -------
    Path
        Location of the file that was written.
    """
    target_dir = Path(dest_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / filename
    urllib.request.urlretrieve(CHEBI_OBO_URL, target)
    return target
+ + Returns + ------- + Path + Path to the downloaded file. + """ + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + dest_path = dest_dir / filename + urllib.request.urlretrieve(CHEBI_SDF_URL, dest_path) + return dest_path diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py new file mode 100644 index 0000000..7279814 --- /dev/null +++ b/chebi_utils/obo_extractor.py @@ -0,0 +1,103 @@ +"""Extract classes and relations from ChEBI OBO ontology files.""" + +from __future__ import annotations + +from pathlib import Path + +import pandas as pd + + +def _parse_obo_stanzas(filepath: str | Path) -> list[dict[str, list[str]]]: + """Parse an OBO file and return a list of stanza dicts.""" + stanzas: list[dict[str, list[str]]] = [] + current_stanza: dict[str, list[str]] | None = None + + with open(filepath, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("!"): + continue + if line.startswith("["): + if current_stanza is not None: + stanzas.append(current_stanza) + stanza_type = line.strip("[]") + current_stanza = {"_type": [stanza_type]} + elif current_stanza is not None and ":" in line: + key, _, value = line.partition(":") + current_stanza.setdefault(key.strip(), []).append(value.strip()) + + if current_stanza is not None: + stanzas.append(current_stanza) + + return stanzas + + +def extract_classes(filepath: str | Path) -> pd.DataFrame: + """Extract ontology classes (terms) from a ChEBI OBO file. + + Parameters + ---------- + filepath : str or Path + Path to the ChEBI OBO file. + + Returns + ------- + pd.DataFrame + DataFrame with columns: id, name, definition, is_obsolete. 
+ """ + stanzas = _parse_obo_stanzas(filepath) + rows = [] + for stanza in stanzas: + if stanza.get("_type", [None])[0] != "Term": + continue + row = { + "id": stanza.get("id", [None])[0], + "name": stanza.get("name", [None])[0], + "definition": stanza.get("def", [None])[0], + "is_obsolete": stanza.get("is_obsolete", ["false"])[0] == "true", + } + rows.append(row) + return pd.DataFrame(rows, columns=["id", "name", "definition", "is_obsolete"]) + + +def extract_relations(filepath: str | Path) -> pd.DataFrame: + """Extract class relations from a ChEBI OBO file. + + Parameters + ---------- + filepath : str or Path + Path to the ChEBI OBO file. + + Returns + ------- + pd.DataFrame + DataFrame with columns: source_id, target_id, relation_type. + """ + stanzas = _parse_obo_stanzas(filepath) + rows = [] + + for stanza in stanzas: + if stanza.get("_type", [None])[0] != "Term": + continue + source_id = stanza.get("id", [None])[0] + if source_id is None: + continue + + for is_a_val in stanza.get("is_a", []): + target_id = is_a_val.split("!")[0].strip() + rows.append({"source_id": source_id, "target_id": target_id, "relation_type": "is_a"}) + + for rel_val in stanza.get("relationship", []): + parts = rel_val.split() + if len(parts) >= 2: + rel_type = parts[0] + target_id = parts[1].split("!")[0].strip() + rows.append( + { + "source_id": source_id, + "target_id": target_id, + "relation_type": rel_type, + } + ) + + return pd.DataFrame(rows, columns=["source_id", "target_id", "relation_type"]) diff --git a/chebi_utils/sdf_extractor.py b/chebi_utils/sdf_extractor.py new file mode 100644 index 0000000..4777133 --- /dev/null +++ b/chebi_utils/sdf_extractor.py @@ -0,0 +1,85 @@ +"""Extract molecule data from ChEBI SDF files.""" + +from __future__ import annotations + +import gzip +from pathlib import Path + +import pandas as pd + + +def _iter_sdf_records(filepath: str | Path): + """Yield individual SDF records as strings.""" + opener = gzip.open if str(filepath).endswith(".gz") else 
open + current_record: list[str] = [] + + with opener(filepath, "rt", encoding="utf-8") as f: + for line in f: + current_record.append(line) + if line.strip() == "$$$$": + yield "".join(current_record) + current_record = [] + + +def _parse_sdf_record(record: str) -> dict[str, str]: + """Parse a single SDF record into a dict of data-item properties.""" + props: dict[str, str] = {} + lines = record.splitlines() + + if lines: + props["mol_name"] = lines[0].strip() + + i = 0 + while i < len(lines): + line = lines[i] + if line.startswith("> <") and line.rstrip().endswith(">"): + key = line.strip()[3:-1] + value_lines: list[str] = [] + i += 1 + while i < len(lines) and lines[i].strip() not in ("", "$$$$"): + value_lines.append(lines[i].strip()) + i += 1 + props[key] = "\n".join(value_lines) + else: + i += 1 + + return props + + +def extract_molecules(filepath: str | Path) -> pd.DataFrame: + """Extract molecule data from a ChEBI SDF file. + + Supports both plain (``.sdf``) and gzip-compressed (``.sdf.gz``) files. + + Parameters + ---------- + filepath : str or Path + Path to the ChEBI SDF (or SDF.gz) file. + + Returns + ------- + pd.DataFrame + DataFrame with one row per molecule. Columns depend on the properties + present in the file. Common columns (renamed for convenience): + chebi_id, name, inchi, inchikey, smiles, formula, charge, mass. 
+ """ + records = [_parse_sdf_record(r) for r in _iter_sdf_records(filepath)] + + if not records: + return pd.DataFrame() + + df = pd.DataFrame(records) + + rename_map = { + "ChEBI ID": "chebi_id", + "ChEBI Name": "name", + "InChI": "inchi", + "InChIKey": "inchikey", + "SMILES": "smiles", + "Formulae": "formula", + "Charge": "charge", + "Mass": "mass", + } + df = df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns}) + + return df diff --git a/chebi_utils/splitter.py b/chebi_utils/splitter.py new file mode 100644 index 0000000..dd86141 --- /dev/null +++ b/chebi_utils/splitter.py @@ -0,0 +1,105 @@ +"""Generate stratified train/validation/test splits from ChEBI DataFrames.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd + + +def create_splits( + df: pd.DataFrame, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + stratify_col: str | None = None, + seed: int = 42, +) -> dict[str, pd.DataFrame]: + """Create stratified train/validation/test splits of a DataFrame. + + Parameters + ---------- + df : pd.DataFrame + Input data to split. + train_ratio : float + Fraction of data for training (default 0.8). + val_ratio : float + Fraction of data for validation (default 0.1). + test_ratio : float + Fraction of data for testing (default 0.1). + stratify_col : str or None + Column name to use for stratification. If None, splits are random. + seed : int + Random seed for reproducibility. + + Returns + ------- + dict + Dictionary with keys ``'train'``, ``'val'``, ``'test'``, each + containing a DataFrame. + + Raises + ------ + ValueError + If the ratios do not sum to 1 or any ratio is outside ``[0, 1]``. 
+ """ + if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: + raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0") + if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]): + raise ValueError("All ratios must be between 0 and 1") + + rng = np.random.default_rng(seed) + + if stratify_col is not None: + return _stratified_split(df, train_ratio, val_ratio, test_ratio, stratify_col, rng) + return _random_split(df, train_ratio, val_ratio, test_ratio, rng) + + +def _random_split( + df: pd.DataFrame, + train_ratio: float, + val_ratio: float, + test_ratio: float, # noqa: ARG001 + rng: np.random.Generator, +) -> dict[str, pd.DataFrame]: + indices = rng.permutation(len(df)) + n_train = int(len(df) * train_ratio) + n_val = int(len(df) * val_ratio) + + train_idx = indices[:n_train] + val_idx = indices[n_train : n_train + n_val] + test_idx = indices[n_train + n_val :] + + return { + "train": df.iloc[train_idx].reset_index(drop=True), + "val": df.iloc[val_idx].reset_index(drop=True), + "test": df.iloc[test_idx].reset_index(drop=True), + } + + +def _stratified_split( + df: pd.DataFrame, + train_ratio: float, + val_ratio: float, + test_ratio: float, # noqa: ARG001 + stratify_col: str, + rng: np.random.Generator, +) -> dict[str, pd.DataFrame]: + train_indices: list[int] = [] + val_indices: list[int] = [] + test_indices: list[int] = [] + + for _, group in df.groupby(stratify_col, sort=False): + group_indices = rng.permutation(np.array(group.index.tolist())) + n = len(group_indices) + n_train = max(1, int(n * train_ratio)) + n_val = max(0, int(n * val_ratio)) + + train_indices.extend(group_indices[:n_train].tolist()) + val_indices.extend(group_indices[n_train : n_train + n_val].tolist()) + test_indices.extend(group_indices[n_train + n_val :].tolist()) + + return { + "train": df.loc[train_indices].reset_index(drop=True), + "val": df.loc[val_indices].reset_index(drop=True), + "test": df.loc[test_indices].reset_index(drop=True), + } diff --git 
a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..717640a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "chebi-utils" +version = "0.1.0" +description = "Common processing functionality for the ChEBI ontology" +readme = "README.md" +license = { file = "LICENSE" } +requires-python = ">=3.10" +dependencies = [ + "numpy>=1.24", + "pandas>=2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "ruff>=0.4", +] + +[tool.ruff] +line-length = 99 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/fixtures/sample.obo b/tests/fixtures/sample.obo new file mode 100644 index 0000000..1fcc965 --- /dev/null +++ b/tests/fixtures/sample.obo @@ -0,0 +1,30 @@ +format-version: 1.2 +ontology: chebi + +[Term] +id: CHEBI:1 +name: compound A +def: "A test compound." [TestDB:001] +is_a: CHEBI:2 ! compound B +relationship: has_role CHEBI:3 ! some role + +[Term] +id: CHEBI:2 +name: compound B +def: "Another test compound." [TestDB:002] +is_a: CHEBI:4 ! root compound + +[Term] +id: CHEBI:3 +name: some role +def: "A biological role." [TestDB:003] + +[Term] +id: CHEBI:4 +name: root compound +def: "The root compound." 
[TestDB:004] +is_obsolete: true + +[Typedef] +id: has_role +name: has role diff --git a/tests/fixtures/sample.sdf b/tests/fixtures/sample.sdf new file mode 100644 index 0000000..f598c07 --- /dev/null +++ b/tests/fixtures/sample.sdf @@ -0,0 +1,58 @@ +compound_a + + 0 0 0 0 0 0 0 0 0 0999 V2000 +M END +> +CHEBI:1 + +> +compound A + +> +C + +> +InChI=1S/CH4/h1H4 + +> +VNWKTOKETHGBQD-UHFFFAOYSA-N + +> +CH4 + +> +0 + +> +16.043 + +$$$$ +compound_b + + 0 0 0 0 0 0 0 0 0 0999 V2000 +M END +> +CHEBI:2 + +> +compound B + +> +CC + +> +InChI=1S/C2H6/c1-2/h1-2H3 + +> +OTMSDBZUPAUEDD-UHFFFAOYSA-N + +> +C2H6 + +> +0 + +> +30.069 + +$$$$ diff --git a/tests/test_downloader.py b/tests/test_downloader.py new file mode 100644 index 0000000..1556cb0 --- /dev/null +++ b/tests/test_downloader.py @@ -0,0 +1,52 @@ +"""Tests for chebi_utils.downloader.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from chebi_utils.downloader import ( + CHEBI_OBO_URL, + CHEBI_SDF_URL, + download_chebi_obo, + download_chebi_sdf, +) + + +def test_download_chebi_obo_calls_urlretrieve(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_obo(dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(CHEBI_OBO_URL, tmp_path / "chebi.obo") + assert result == tmp_path / "chebi.obo" + + +def test_download_chebi_sdf_calls_urlretrieve(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_sdf(dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(CHEBI_SDF_URL, tmp_path / "chebi.sdf.gz") + assert result == tmp_path / "chebi.sdf.gz" + + +def test_download_chebi_obo_custom_filename(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + result = download_chebi_obo(dest_dir=tmp_path, filename="my_chebi.obo") + assert result == tmp_path / "my_chebi.obo" + + +def 
test_download_chebi_sdf_custom_filename(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + result = download_chebi_sdf(dest_dir=tmp_path, filename="my_chebi.sdf.gz") + assert result == tmp_path / "my_chebi.sdf.gz" + + +def test_download_creates_dest_dir(tmp_path): + new_dir = tmp_path / "subdir" / "nested" + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + download_chebi_obo(dest_dir=new_dir) + assert new_dir.is_dir() + + +def test_download_returns_path_object(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve"): + result = download_chebi_obo(dest_dir=str(tmp_path)) + assert isinstance(result, Path) diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py new file mode 100644 index 0000000..e8b1980 --- /dev/null +++ b/tests/test_obo_extractor.py @@ -0,0 +1,72 @@ +"""Tests for chebi_utils.obo_extractor.""" + +from __future__ import annotations + +from pathlib import Path + +from chebi_utils.obo_extractor import extract_classes, extract_relations + +FIXTURES = Path(__file__).parent / "fixtures" +SAMPLE_OBO = FIXTURES / "sample.obo" + + +class TestExtractClasses: + def test_returns_dataframe_with_expected_columns(self): + df = extract_classes(SAMPLE_OBO) + assert list(df.columns) == ["id", "name", "definition", "is_obsolete"] + + def test_correct_number_of_terms(self): + df = extract_classes(SAMPLE_OBO) + assert len(df) == 4 + + def test_term_ids_are_present(self): + df = extract_classes(SAMPLE_OBO) + assert set(df["id"]) == {"CHEBI:1", "CHEBI:2", "CHEBI:3", "CHEBI:4"} + + def test_term_names_are_correct(self): + df = extract_classes(SAMPLE_OBO) + row = df[df["id"] == "CHEBI:1"].iloc[0] + assert row["name"] == "compound A" + + def test_obsolete_flag(self): + df = extract_classes(SAMPLE_OBO) + assert df[df["id"] == "CHEBI:4"].iloc[0]["is_obsolete"] + assert not df[df["id"] == "CHEBI:1"].iloc[0]["is_obsolete"] + + def test_definition_is_extracted(self): + df = 
extract_classes(SAMPLE_OBO) + row = df[df["id"] == "CHEBI:1"].iloc[0] + assert "test compound" in row["definition"] + + def test_typedef_stanzas_are_excluded(self): + df = extract_classes(SAMPLE_OBO) + assert "has_role" not in df["id"].values + + +class TestExtractRelations: + def test_returns_dataframe_with_expected_columns(self): + df = extract_relations(SAMPLE_OBO) + assert list(df.columns) == ["source_id", "target_id", "relation_type"] + + def test_isa_relations_extracted(self): + df = extract_relations(SAMPLE_OBO) + isa = df[df["relation_type"] == "is_a"] + assert len(isa) == 2 + + def test_typed_relation_extracted(self): + df = extract_relations(SAMPLE_OBO) + has_role = df[df["relation_type"] == "has_role"] + assert len(has_role) == 1 + assert has_role.iloc[0]["source_id"] == "CHEBI:1" + assert has_role.iloc[0]["target_id"] == "CHEBI:3" + + def test_isa_source_and_target(self): + df = extract_relations(SAMPLE_OBO) + isa = df[df["relation_type"] == "is_a"] + sources = set(isa["source_id"]) + assert "CHEBI:1" in sources + assert "CHEBI:2" in sources + + def test_total_relations_count(self): + df = extract_relations(SAMPLE_OBO) + assert len(df) == 3 diff --git a/tests/test_sdf_extractor.py b/tests/test_sdf_extractor.py new file mode 100644 index 0000000..63e0c03 --- /dev/null +++ b/tests/test_sdf_extractor.py @@ -0,0 +1,66 @@ +"""Tests for chebi_utils.sdf_extractor.""" + +from __future__ import annotations + +import gzip +from pathlib import Path + +from chebi_utils.sdf_extractor import extract_molecules + +FIXTURES = Path(__file__).parent / "fixtures" +SAMPLE_SDF = FIXTURES / "sample.sdf" + + +class TestExtractMolecules: + def test_returns_two_molecules(self): + df = extract_molecules(SAMPLE_SDF) + assert len(df) == 2 + + def test_chebi_id_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "chebi_id" in df.columns + + def test_chebi_ids_correct(self): + df = extract_molecules(SAMPLE_SDF) + assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + 
+ def test_name_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "name" in df.columns + + def test_smiles_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "smiles" in df.columns + + def test_formula_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "formula" in df.columns + + def test_inchi_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "inchi" in df.columns + + def test_inchikey_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "inchikey" in df.columns + + def test_molecule_properties(self): + df = extract_molecules(SAMPLE_SDF) + row = df[df["chebi_id"] == "CHEBI:1"].iloc[0] + assert row["name"] == "compound A" + assert row["smiles"] == "C" + assert row["formula"] == "CH4" + + def test_gzipped_sdf(self, tmp_path): + gz_path = tmp_path / "sample.sdf.gz" + with open(SAMPLE_SDF, "rb") as f_in, gzip.open(gz_path, "wb") as f_out: + f_out.write(f_in.read()) + df = extract_molecules(gz_path) + assert len(df) == 2 + assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + + def test_empty_sdf_returns_empty_dataframe(self, tmp_path): + empty_sdf = tmp_path / "empty.sdf" + empty_sdf.write_text("") + df = extract_molecules(empty_sdf) + assert df.empty diff --git a/tests/test_splitter.py b/tests/test_splitter.py new file mode 100644 index 0000000..10b9434 --- /dev/null +++ b/tests/test_splitter.py @@ -0,0 +1,105 @@ +"""Tests for chebi_utils.splitter.""" + +from __future__ import annotations + +import pandas as pd +import pytest + +from chebi_utils.splitter import create_splits + + +@pytest.fixture +def sample_df(): + return pd.DataFrame( + { + "id": [f"CHEBI:{i}" for i in range(100)], + "category": (["A"] * 50) + (["B"] * 30) + (["C"] * 20), + } + ) + + +class TestCreateSplitsRandom: + def test_returns_three_splits(self, sample_df): + splits = create_splits(sample_df) + assert set(splits.keys()) == {"train", "val", "test"} + + def test_split_sizes_sum_to_total(self, sample_df): + 
splits = create_splits(sample_df) + total = sum(len(v) for v in splits.values()) + assert total == len(sample_df) + + def test_default_ratios(self, sample_df): + splits = create_splits(sample_df) + assert len(splits["train"]) == 80 + assert len(splits["val"]) == 10 + assert len(splits["test"]) == 10 + + def test_reproducible_with_same_seed(self, sample_df): + splits1 = create_splits(sample_df, seed=0) + splits2 = create_splits(sample_df, seed=0) + pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) + + def test_different_seeds_give_different_splits(self, sample_df): + splits1 = create_splits(sample_df, seed=0) + splits2 = create_splits(sample_df, seed=1) + assert not splits1["train"]["id"].equals(splits2["train"]["id"]) + + def test_no_overlap_between_splits(self, sample_df): + splits = create_splits(sample_df) + train_ids = set(splits["train"]["id"]) + val_ids = set(splits["val"]["id"]) + test_ids = set(splits["test"]["id"]) + assert train_ids.isdisjoint(val_ids) + assert train_ids.isdisjoint(test_ids) + assert val_ids.isdisjoint(test_ids) + + def test_all_rows_covered(self, sample_df): + splits = create_splits(sample_df) + all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"]) + assert all_ids == set(sample_df["id"]) + + def test_custom_ratios(self, sample_df): + splits = create_splits(sample_df, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1) + assert len(splits["train"]) == 70 + assert len(splits["val"]) == 20 + assert len(splits["test"]) == 10 + + def test_invalid_ratios_raise_error(self, sample_df): + with pytest.raises(ValueError, match="must equal 1.0"): + create_splits(sample_df, train_ratio=0.5, val_ratio=0.3, test_ratio=0.3) + + def test_negative_ratio_raises_error(self, sample_df): + with pytest.raises(ValueError, match="between 0 and 1"): + create_splits(sample_df, train_ratio=-0.1, val_ratio=0.6, test_ratio=0.5) + + +class TestCreateSplitsStratified: + def 
test_stratified_split_returns_three_splits(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + assert set(splits.keys()) == {"train", "val", "test"} + + def test_stratified_sizes_sum_to_total(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + total = sum(len(v) for v in splits.values()) + assert total == len(sample_df) + + def test_stratified_preserves_class_proportions(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + train = splits["train"] + counts = train["category"].value_counts(normalize=True) + # Category A: 50% of data -> ~50% of train + assert abs(counts.get("A", 0) - 0.5) < 0.1 + + def test_stratified_no_overlap(self, sample_df): + splits = create_splits(sample_df, stratify_col="category") + train_ids = set(splits["train"]["id"]) + val_ids = set(splits["val"]["id"]) + test_ids = set(splits["test"]["id"]) + assert train_ids.isdisjoint(val_ids) + assert train_ids.isdisjoint(test_ids) + assert val_ids.isdisjoint(test_ids) + + def test_stratified_reproducible(self, sample_df): + splits1 = create_splits(sample_df, stratify_col="category", seed=42) + splits2 = create_splits(sample_df, stratify_col="category", seed=42) + pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) From eb6a7fe17c805f59c408b058441d5bc15557faeb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:12:28 +0000 Subject: [PATCH 3/6] Update downloader to accept version parameter with versioned archive URLs Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- chebi_utils/downloader.py | 41 +++++++++++++++++---- tests/test_downloader.py | 77 ++++++++++++++++++++++++++++++++------- 2 files changed, 97 insertions(+), 21 deletions(-) diff --git a/chebi_utils/downloader.py b/chebi_utils/downloader.py index 81358eb..2c320af 100644 --- a/chebi_utils/downloader.py +++ b/chebi_utils/downloader.py @@ -5,15 
+5,33 @@ import urllib.request from pathlib import Path -CHEBI_OBO_URL = "https://ftp.ebi.ac.uk/pub/databases/chebi/ontology/chebi.obo" -CHEBI_SDF_URL = "https://ftp.ebi.ac.uk/pub/databases/chebi/SDF/ChEBI_complete.sdf.gz" +_CHEBI_LEGACY_VERSION_THRESHOLD = 245 -def download_chebi_obo(dest_dir: str | Path = ".", filename: str = "chebi.obo") -> Path: - """Download the ChEBI OBO ontology file from the EBI FTP server. +def _chebi_obo_url(version: int) -> str: + if version < _CHEBI_LEGACY_VERSION_THRESHOLD: + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/ontology/chebi.obo" + + +def _chebi_sdf_url(version: int) -> str: + if version < _CHEBI_LEGACY_VERSION_THRESHOLD: + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" + return f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/SDF/chebi.sdf.gz" + + +def download_chebi_obo( + version: int, + dest_dir: str | Path = ".", + filename: str = "chebi.obo", +) -> Path: + """Download a versioned ChEBI OBO ontology file from the EBI FTP server. Parameters ---------- + version : int + ChEBI release version number (e.g. 230, 245, 250). + Versions below 245 are fetched from the legacy archive path. dest_dir : str or Path Directory where the file will be saved (created if it doesn't exist). filename : str @@ -27,15 +45,22 @@ def download_chebi_obo(dest_dir: str | Path = ".", filename: str = "chebi.obo") dest_dir = Path(dest_dir) dest_dir.mkdir(parents=True, exist_ok=True) dest_path = dest_dir / filename - urllib.request.urlretrieve(CHEBI_OBO_URL, dest_path) + urllib.request.urlretrieve(_chebi_obo_url(version), dest_path) return dest_path -def download_chebi_sdf(dest_dir: str | Path = ".", filename: str = "chebi.sdf.gz") -> Path: - """Download the ChEBI SDF file from the EBI FTP server. 
+def download_chebi_sdf( + version: int, + dest_dir: str | Path = ".", + filename: str = "chebi.sdf.gz", +) -> Path: + """Download a versioned ChEBI SDF file from the EBI FTP server. Parameters ---------- + version : int + ChEBI release version number (e.g. 230, 245, 250). + Versions below 245 are fetched from the legacy archive path. dest_dir : str or Path Directory where the file will be saved (created if it doesn't exist). filename : str @@ -49,5 +74,5 @@ def download_chebi_sdf(dest_dir: str | Path = ".", filename: str = "chebi.sdf.gz dest_dir = Path(dest_dir) dest_dir.mkdir(parents=True, exist_ok=True) dest_path = dest_dir / filename - urllib.request.urlretrieve(CHEBI_SDF_URL, dest_path) + urllib.request.urlretrieve(_chebi_sdf_url(version), dest_path) return dest_path diff --git a/tests/test_downloader.py b/tests/test_downloader.py index 1556cb0..126e791 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -6,47 +6,98 @@ from unittest.mock import patch from chebi_utils.downloader import ( - CHEBI_OBO_URL, - CHEBI_SDF_URL, + _chebi_obo_url, + _chebi_sdf_url, download_chebi_obo, download_chebi_sdf, ) +# --- URL helper tests --- -def test_download_chebi_obo_calls_urlretrieve(tmp_path): + +def test_obo_url_modern_version(): + url = _chebi_obo_url(245) + assert url == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel245/ontology/chebi.obo" + + +def test_obo_url_legacy_version(): + url = _chebi_obo_url(244) + assert ( + url + == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel244/ontology/chebi.obo" + ) + + +def test_sdf_url_modern_version(): + url = _chebi_sdf_url(245) + assert url == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel245/SDF/chebi.sdf.gz" + + +def test_sdf_url_legacy_version(): + url = _chebi_sdf_url(230) + assert ( + url + == "https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel230/ontology/chebi.obo" + ) + + +# --- download_chebi_obo tests --- + + +def 
test_download_chebi_obo_calls_urlretrieve_modern(tmp_path): with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: - result = download_chebi_obo(dest_dir=tmp_path) - mock_retrieve.assert_called_once_with(CHEBI_OBO_URL, tmp_path / "chebi.obo") + result = download_chebi_obo(version=250, dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(_chebi_obo_url(250), tmp_path / "chebi.obo") assert result == tmp_path / "chebi.obo" -def test_download_chebi_sdf_calls_urlretrieve(tmp_path): +def test_download_chebi_obo_calls_urlretrieve_legacy(tmp_path): with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: - result = download_chebi_sdf(dest_dir=tmp_path) - mock_retrieve.assert_called_once_with(CHEBI_SDF_URL, tmp_path / "chebi.sdf.gz") - assert result == tmp_path / "chebi.sdf.gz" + result = download_chebi_obo(version=230, dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(_chebi_obo_url(230), tmp_path / "chebi.obo") + assert result == tmp_path / "chebi.obo" def test_download_chebi_obo_custom_filename(tmp_path): with patch("chebi_utils.downloader.urllib.request.urlretrieve"): - result = download_chebi_obo(dest_dir=tmp_path, filename="my_chebi.obo") + result = download_chebi_obo(version=250, dest_dir=tmp_path, filename="my_chebi.obo") assert result == tmp_path / "my_chebi.obo" +# --- download_chebi_sdf tests --- + + +def test_download_chebi_sdf_calls_urlretrieve_modern(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_sdf(version=250, dest_dir=tmp_path) + mock_retrieve.assert_called_once_with(_chebi_sdf_url(250), tmp_path / "chebi.sdf.gz") + assert result == tmp_path / "chebi.sdf.gz" + + +def test_download_chebi_sdf_calls_urlretrieve_legacy(tmp_path): + with patch("chebi_utils.downloader.urllib.request.urlretrieve") as mock_retrieve: + result = download_chebi_sdf(version=230, dest_dir=tmp_path) + 
mock_retrieve.assert_called_once_with(_chebi_sdf_url(230), tmp_path / "chebi.sdf.gz") + assert result == tmp_path / "chebi.sdf.gz" + + def test_download_chebi_sdf_custom_filename(tmp_path): with patch("chebi_utils.downloader.urllib.request.urlretrieve"): - result = download_chebi_sdf(dest_dir=tmp_path, filename="my_chebi.sdf.gz") + result = download_chebi_sdf(version=250, dest_dir=tmp_path, filename="my_chebi.sdf.gz") assert result == tmp_path / "my_chebi.sdf.gz" +# --- shared behaviour tests --- + + def test_download_creates_dest_dir(tmp_path): new_dir = tmp_path / "subdir" / "nested" with patch("chebi_utils.downloader.urllib.request.urlretrieve"): - download_chebi_obo(dest_dir=new_dir) + download_chebi_obo(version=250, dest_dir=new_dir) assert new_dir.is_dir() def test_download_returns_path_object(tmp_path): with patch("chebi_utils.downloader.urllib.request.urlretrieve"): - result = download_chebi_obo(dest_dir=str(tmp_path)) + result = download_chebi_obo(version=250, dest_dir=str(tmp_path)) assert isinstance(result, Path) From 1df6eeef9861b90a7001c5e2a84895bb2e4bd97d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:19:57 +0000 Subject: [PATCH 4/6] Rewrite obo_extractor to use fastobo and return a networkx DiGraph Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- chebi_utils/__init__.py | 5 +- chebi_utils/obo_extractor.py | 164 +++++++++++++++++++---------------- pyproject.toml | 2 + tests/fixtures/sample.obo | 23 +++-- tests/test_obo_extractor.py | 136 ++++++++++++++++------------- 5 files changed, 181 insertions(+), 149 deletions(-) diff --git a/chebi_utils/__init__.py b/chebi_utils/__init__.py index 41a117f..8367bab 100644 --- a/chebi_utils/__init__.py +++ b/chebi_utils/__init__.py @@ -1,13 +1,12 @@ from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf -from chebi_utils.obo_extractor import extract_classes, extract_relations +from 
chebi_utils.obo_extractor import build_chebi_graph from chebi_utils.sdf_extractor import extract_molecules from chebi_utils.splitter import create_splits __all__ = [ "download_chebi_obo", "download_chebi_sdf", - "extract_classes", - "extract_relations", + "build_chebi_graph", "extract_molecules", "create_splits", ] diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py index 7279814..4ef37ff 100644 --- a/chebi_utils/obo_extractor.py +++ b/chebi_utils/obo_extractor.py @@ -1,67 +1,80 @@ -"""Extract classes and relations from ChEBI OBO ontology files.""" +"""Extract ChEBI ontology data using fastobo and build a networkx graph.""" from __future__ import annotations from pathlib import Path -import pandas as pd +import fastobo +import networkx as nx -def _parse_obo_stanzas(filepath: str | Path) -> list[dict[str, list[str]]]: - """Parse an OBO file and return a list of stanza dicts.""" - stanzas: list[dict[str, list[str]]] = [] - current_stanza: dict[str, list[str]] | None = None +def _chebi_id_to_int(chebi_id: str) -> int: + """Convert 'CHEBI:123' to 123.""" + return int(chebi_id.split(":")[1]) - with open(filepath, encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line or line.startswith("!"): - continue - if line.startswith("["): - if current_stanza is not None: - stanzas.append(current_stanza) - stanza_type = line.strip("[]") - current_stanza = {"_type": [stanza_type]} - elif current_stanza is not None and ":" in line: - key, _, value = line.partition(":") - current_stanza.setdefault(key.strip(), []).append(value.strip()) - - if current_stanza is not None: - stanzas.append(current_stanza) - - return stanzas - - -def extract_classes(filepath: str | Path) -> pd.DataFrame: - """Extract ontology classes (terms) from a ChEBI OBO file. - Parameters - ---------- - filepath : str or Path - Path to the ChEBI OBO file. +def _term_data(doc: "fastobo.term.TermFrame") -> dict | None: + """Extract data from a single fastobo TermFrame. 
Returns ------- - pd.DataFrame - DataFrame with columns: id, name, definition, is_obsolete. + dict or None + Parsed term data, or ``None`` if the term is marked as obsolete. """ - stanzas = _parse_obo_stanzas(filepath) - rows = [] - for stanza in stanzas: - if stanza.get("_type", [None])[0] != "Term": - continue - row = { - "id": stanza.get("id", [None])[0], - "name": stanza.get("name", [None])[0], - "definition": stanza.get("def", [None])[0], - "is_obsolete": stanza.get("is_obsolete", ["false"])[0] == "true", - } - rows.append(row) - return pd.DataFrame(rows, columns=["id", "name", "definition", "is_obsolete"]) - - -def extract_relations(filepath: str | Path) -> pd.DataFrame: - """Extract class relations from a ChEBI OBO file. + parents: list[int] = [] + has_part: set[int] = set() + name: str | None = None + smiles: str | None = None + subset: str | None = None + + for clause in doc: + if isinstance(clause, fastobo.term.IsObsoleteClause): + if clause.obsolete: + return None + elif isinstance(clause, fastobo.term.PropertyValueClause): + pv = clause.property_value + if str(pv.relation) in ( + "chemrof:smiles_string", + "http://purl.obolibrary.org/obo/chebi/smiles", + ): + smiles = pv.value + elif isinstance(clause, fastobo.term.SynonymClause): + if "SMILES" in clause.raw_value() and smiles is None: + smiles = clause.raw_value().split('"')[1] + elif isinstance(clause, fastobo.term.RelationshipClause): + if str(clause.typedef) == "has_part": + has_part.add(_chebi_id_to_int(str(clause.term))) + elif isinstance(clause, fastobo.term.IsAClause): + parents.append(_chebi_id_to_int(str(clause.term))) + elif isinstance(clause, fastobo.term.NameClause): + name = str(clause.name) + elif isinstance(clause, fastobo.term.SubsetClause): + subset = str(clause.subset) + + return { + "id": _chebi_id_to_int(str(doc.id)), + "parents": parents, + "has_part": has_part, + "name": name, + "smiles": smiles, + "subset": subset, + } + + +def build_chebi_graph(filepath: str | Path) -> 
nx.DiGraph: + """Parse a ChEBI OBO file and build a directed graph of ontology terms. + + ``xref:`` lines are stripped before parsing as they can cause fastobo + errors on some ChEBI releases. Only non-obsolete CHEBI-prefixed terms + are included. + + **Nodes** are integer CHEBI IDs (e.g. ``1`` for ``CHEBI:1``) with + attributes ``name``, ``smiles``, and ``subset``. + + **Edges** carry a ``relation`` attribute and represent: + + - ``is_a`` — directed from child to parent + - ``has_part`` — directed from whole to part Parameters ---------- @@ -70,34 +83,31 @@ def extract_relations(filepath: str | Path) -> pd.DataFrame: Returns ------- - pd.DataFrame - DataFrame with columns: source_id, target_id, relation_type. + nx.DiGraph + Directed graph of ChEBI ontology terms and their relationships. """ - stanzas = _parse_obo_stanzas(filepath) - rows = [] + with open(filepath, encoding="utf-8") as f: + content = "\n".join(line for line in f if not line.startswith("xref:")) + + graph: nx.DiGraph = nx.DiGraph() - for stanza in stanzas: - if stanza.get("_type", [None])[0] != "Term": + for frame in fastobo.loads(content): + if not ( + frame and isinstance(frame.id, fastobo.id.PrefixedIdent) and frame.id.prefix == "CHEBI" + ): continue - source_id = stanza.get("id", [None])[0] - if source_id is None: + + term = _term_data(frame) + if term is None: continue - for is_a_val in stanza.get("is_a", []): - target_id = is_a_val.split("!")[0].strip() - rows.append({"source_id": source_id, "target_id": target_id, "relation_type": "is_a"}) - - for rel_val in stanza.get("relationship", []): - parts = rel_val.split() - if len(parts) >= 2: - rel_type = parts[0] - target_id = parts[1].split("!")[0].strip() - rows.append( - { - "source_id": source_id, - "target_id": target_id, - "relation_type": rel_type, - } - ) - - return pd.DataFrame(rows, columns=["source_id", "target_id", "relation_type"]) + node_id = term["id"] + graph.add_node(node_id, name=term["name"], smiles=term["smiles"], 
subset=term["subset"]) + + for parent_id in term["parents"]: + graph.add_edge(node_id, parent_id, relation="is_a") + + for part_id in term["has_part"]: + graph.add_edge(node_id, part_id, relation="has_part") + + return graph diff --git a/pyproject.toml b/pyproject.toml index 717640a..1a23ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" dependencies = [ + "fastobo>=0.14", + "networkx>=3.0", "numpy>=1.24", "pandas>=2.0", ] diff --git a/tests/fixtures/sample.obo b/tests/fixtures/sample.obo index 1fcc965..2b486fa 100644 --- a/tests/fixtures/sample.obo +++ b/tests/fixtures/sample.obo @@ -6,25 +6,32 @@ id: CHEBI:1 name: compound A def: "A test compound." [TestDB:001] is_a: CHEBI:2 ! compound B -relationship: has_role CHEBI:3 ! some role +relationship: has_part CHEBI:3 ! methyl group +property_value: http://purl.obolibrary.org/obo/chebi/smiles "C" xsd:string [Term] id: CHEBI:2 name: compound B def: "Another test compound." [TestDB:002] -is_a: CHEBI:4 ! root compound +is_a: CHEBI:5 ! root compound [Term] id: CHEBI:3 -name: some role -def: "A biological role." [TestDB:003] +name: methyl group +def: "A methyl group." [TestDB:003] +subset: 3_STAR [Term] id: CHEBI:4 -name: root compound -def: "The root compound." [TestDB:004] +name: obsolete term +def: "This term is obsolete." [TestDB:004] is_obsolete: true +[Term] +id: CHEBI:5 +name: root compound +def: "The root compound." 
[TestDB:005] + [Typedef] -id: has_role -name: has role +id: has_part +name: has part diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py index e8b1980..ed3f8e5 100644 --- a/tests/test_obo_extractor.py +++ b/tests/test_obo_extractor.py @@ -4,69 +4,83 @@ from pathlib import Path -from chebi_utils.obo_extractor import extract_classes, extract_relations +import networkx as nx + +from chebi_utils.obo_extractor import build_chebi_graph FIXTURES = Path(__file__).parent / "fixtures" SAMPLE_OBO = FIXTURES / "sample.obo" -class TestExtractClasses: - def test_returns_dataframe_with_expected_columns(self): - df = extract_classes(SAMPLE_OBO) - assert list(df.columns) == ["id", "name", "definition", "is_obsolete"] - - def test_correct_number_of_terms(self): - df = extract_classes(SAMPLE_OBO) - assert len(df) == 4 - - def test_term_ids_are_present(self): - df = extract_classes(SAMPLE_OBO) - assert set(df["id"]) == {"CHEBI:1", "CHEBI:2", "CHEBI:3", "CHEBI:4"} - - def test_term_names_are_correct(self): - df = extract_classes(SAMPLE_OBO) - row = df[df["id"] == "CHEBI:1"].iloc[0] - assert row["name"] == "compound A" - - def test_obsolete_flag(self): - df = extract_classes(SAMPLE_OBO) - assert df[df["id"] == "CHEBI:4"].iloc[0]["is_obsolete"] - assert not df[df["id"] == "CHEBI:1"].iloc[0]["is_obsolete"] - - def test_definition_is_extracted(self): - df = extract_classes(SAMPLE_OBO) - row = df[df["id"] == "CHEBI:1"].iloc[0] - assert "test compound" in row["definition"] - - def test_typedef_stanzas_are_excluded(self): - df = extract_classes(SAMPLE_OBO) - assert "has_role" not in df["id"].values - - -class TestExtractRelations: - def test_returns_dataframe_with_expected_columns(self): - df = extract_relations(SAMPLE_OBO) - assert list(df.columns) == ["source_id", "target_id", "relation_type"] - - def test_isa_relations_extracted(self): - df = extract_relations(SAMPLE_OBO) - isa = df[df["relation_type"] == "is_a"] - assert len(isa) == 2 - - def 
test_typed_relation_extracted(self): - df = extract_relations(SAMPLE_OBO) - has_role = df[df["relation_type"] == "has_role"] - assert len(has_role) == 1 - assert has_role.iloc[0]["source_id"] == "CHEBI:1" - assert has_role.iloc[0]["target_id"] == "CHEBI:3" - - def test_isa_source_and_target(self): - df = extract_relations(SAMPLE_OBO) - isa = df[df["relation_type"] == "is_a"] - sources = set(isa["source_id"]) - assert "CHEBI:1" in sources - assert "CHEBI:2" in sources - - def test_total_relations_count(self): - df = extract_relations(SAMPLE_OBO) - assert len(df) == 3 +class TestBuildChebiGraph: + def test_returns_directed_graph(self): + g = build_chebi_graph(SAMPLE_OBO) + assert isinstance(g, nx.DiGraph) + + def test_correct_number_of_nodes(self): + # CHEBI:4 is obsolete and must be excluded -> 4 nodes remain + g = build_chebi_graph(SAMPLE_OBO) + assert len(g.nodes) == 4 + + def test_node_ids_are_integers(self): + g = build_chebi_graph(SAMPLE_OBO) + assert all(isinstance(n, int) for n in g.nodes) + + def test_expected_nodes_present(self): + g = build_chebi_graph(SAMPLE_OBO) + assert set(g.nodes) == {1, 2, 3, 5} + + def test_obsolete_term_excluded(self): + g = build_chebi_graph(SAMPLE_OBO) + assert 4 not in g.nodes + + def test_node_name_attribute(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes[1]["name"] == "compound A" + assert g.nodes[2]["name"] == "compound B" + + def test_smiles_extracted_from_property_value(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes[1]["smiles"] == "C" + + def test_smiles_none_when_absent(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes[2]["smiles"] is None + + def test_subset_extracted(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes[3]["subset"] == "3_STAR" + + def test_subset_none_when_absent(self): + g = build_chebi_graph(SAMPLE_OBO) + assert g.nodes[1]["subset"] is None + + def test_isa_edge_present(self): + g = build_chebi_graph(SAMPLE_OBO) + # CHEBI:1 is_a CHEBI:2 + assert g.has_edge(1, 
2) + assert g.edges[1, 2]["relation"] == "is_a" + + def test_has_part_edge_present(self): + g = build_chebi_graph(SAMPLE_OBO) + # CHEBI:1 has_part CHEBI:3 + assert g.has_edge(1, 3) + assert g.edges[1, 3]["relation"] == "has_part" + + def test_total_edge_count(self): + g = build_chebi_graph(SAMPLE_OBO) + # 1->2 (is_a), 1->3 (has_part), 2->5 (is_a) + assert len(g.edges) == 3 + + def test_typedef_stanza_excluded(self): + g = build_chebi_graph(SAMPLE_OBO) + # "has_part" Typedef id is not numeric CHEBI ID, should not appear as node + assert "has_part" not in g.nodes + + def test_xref_lines_do_not_break_parsing(self, tmp_path): + obo_with_xrefs = tmp_path / "xref.obo" + obo_with_xrefs.write_text( + "format-version: 1.2\n[Term]\nid: CHEBI:10\nname: test\nxref: Reaxys:123456\n" + ) + g = build_chebi_graph(obo_with_xrefs) + assert 10 in g.nodes From b2e747a26a042d7a0c89fccaa8f608778aaccef9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:27:13 +0000 Subject: [PATCH 5/6] Add RDKit Mol parsing to SDF extractor with partial sanitisation Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- chebi_utils/sdf_extractor.py | 111 +++++++++++++++++++++++++++++------ pyproject.toml | 1 + tests/fixtures/sample.sdf | 10 +++- tests/test_sdf_extractor.py | 34 +++++++++++ 4 files changed, 137 insertions(+), 19 deletions(-) diff --git a/chebi_utils/sdf_extractor.py b/chebi_utils/sdf_extractor.py index 4777133..6c2e63b 100644 --- a/chebi_utils/sdf_extractor.py +++ b/chebi_utils/sdf_extractor.py @@ -3,9 +3,56 @@ from __future__ import annotations import gzip +import warnings from pathlib import Path import pandas as pd +from rdkit import Chem + + +def _update_mol_valences(mol: Chem.Mol) -> Chem.Mol: + """Mark all atoms as having no implicit hydrogens to preserve molfile valences.""" + for atom in mol.GetAtoms(): + atom.SetNoImplicit(True) + return mol + + +def _parse_molblock(molblock: str, 
chebi_id: str | None = None) -> Chem.Mol | None: + """Parse a V2000/V3000 molblock into an RDKit Mol object. + + Uses partial sanitisation to handle ChEBI molecules with unusual valences + or radicals. + + Parameters + ---------- + molblock : str + The molblock string (header + atom/bond table + ``M END``). + chebi_id : str or None + Used only for the warning message when parsing fails. + + Returns + ------- + Chem.Mol or None + Parsed molecule, or ``None`` if parsing failed. + """ + mol = Chem.MolFromMolBlock(molblock, sanitize=False, removeHs=False) + if mol is None: + warnings.warn(f"Failed to parse molblock for {chebi_id}", stacklevel=2) + return None + mol = _update_mol_valences(mol) + Chem.SanitizeMol( + mol, + sanitizeOps=( + Chem.SanitizeFlags.SANITIZE_FINDRADICALS + | Chem.SanitizeFlags.SANITIZE_KEKULIZE + | Chem.SanitizeFlags.SANITIZE_SETAROMATICITY + | Chem.SanitizeFlags.SANITIZE_SETCONJUGATION + | Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION + | Chem.SanitizeFlags.SANITIZE_SYMMRINGS + ), + catchErrors=True, + ) + return mol def _iter_sdf_records(filepath: str | Path): @@ -21,19 +68,39 @@ def _iter_sdf_records(filepath: str | Path): current_record = [] -def _parse_sdf_record(record: str) -> dict[str, str]: - """Parse a single SDF record into a dict of data-item properties.""" - props: dict[str, str] = {} - lines = record.splitlines() +def _parse_sdf_record(record: str) -> tuple[dict[str, str], str]: + """Parse a single SDF record. - if lines: - props["mol_name"] = lines[0].strip() - - i = 0 + Returns + ------- + tuple[dict[str, str], str] + ``(props, molblock)`` where *props* is a dict of data-item key/values + and *molblock* is the raw connection-table string. 
+ """ + props: dict[str, str] = {} + lines = record.splitlines(keepends=True) + + # Collect molblock: everything up to (but not including) the first "> <" tag + molblock_lines: list[str] = [] + data_start = len(lines) + for idx, line in enumerate(lines): + stripped = line.strip() + if stripped.startswith("> <") or stripped == "$$$$": + data_start = idx + break + molblock_lines.append(line) + molblock = "".join(molblock_lines) + + # Extract header name (first line of molblock) + if molblock_lines: + props["mol_name"] = molblock_lines[0].strip() + + # Parse data items + i = data_start while i < len(lines): - line = lines[i] - if line.startswith("> <") and line.rstrip().endswith(">"): - key = line.strip()[3:-1] + line = lines[i].strip() + if line.startswith("> <") and line.endswith(">"): + key = line[3:-1] value_lines: list[str] = [] i += 1 while i < len(lines) and lines[i].strip() not in ("", "$$$$"): @@ -43,13 +110,15 @@ def _parse_sdf_record(record: str) -> dict[str, str]: else: i += 1 - return props + return props, molblock def extract_molecules(filepath: str | Path) -> pd.DataFrame: """Extract molecule data from a ChEBI SDF file. Supports both plain (``.sdf``) and gzip-compressed (``.sdf.gz``) files. + Each molecule is parsed into an RDKit ``Mol`` object stored in the ``mol`` + column. Molecules that cannot be parsed result in ``None`` in that column. Parameters ---------- @@ -61,14 +130,19 @@ def extract_molecules(filepath: str | Path) -> pd.DataFrame: pd.DataFrame DataFrame with one row per molecule. Columns depend on the properties present in the file. Common columns (renamed for convenience): - chebi_id, name, inchi, inchikey, smiles, formula, charge, mass. + chebi_id, name, inchi, inchikey, smiles, formula, charge, mass, mol. 
""" - records = [_parse_sdf_record(r) for r in _iter_sdf_records(filepath)] - - if not records: + rows = [] + molblocks = [] + for record in _iter_sdf_records(filepath): + props, molblock = _parse_sdf_record(record) + rows.append(props) + molblocks.append(molblock) + + if not rows: return pd.DataFrame() - df = pd.DataFrame(records) + df = pd.DataFrame(rows) rename_map = { "ChEBI ID": "chebi_id", @@ -82,4 +156,7 @@ def extract_molecules(filepath: str | Path) -> pd.DataFrame: } df = df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns}) + chebi_ids = df["chebi_id"].tolist() if "chebi_id" in df.columns else [None] * len(df) + df["mol"] = [_parse_molblock(mb, cid) for mb, cid in zip(molblocks, chebi_ids, strict=False)] + return df diff --git a/pyproject.toml b/pyproject.toml index 1a23ab9..69e3be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "networkx>=3.0", "numpy>=1.24", "pandas>=2.0", + "rdkit>=2022.09", ] [project.optional-dependencies] diff --git a/tests/fixtures/sample.sdf b/tests/fixtures/sample.sdf index f598c07..e877215 100644 --- a/tests/fixtures/sample.sdf +++ b/tests/fixtures/sample.sdf @@ -1,6 +1,8 @@ compound_a + RDKit - 0 0 0 0 0 0 0 0 0 0999 V2000 + 1 0 0 0 0 0 0 0 0 0999 V2000 + 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 M END > CHEBI:1 @@ -28,8 +30,12 @@ CH4 $$$$ compound_b + RDKit - 0 0 0 0 0 0 0 0 0 0999 V2000 + 2 1 0 0 0 0 0 0 0 0999 V2000 + 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 + 1.5400 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0 + 1 2 1 0 M END > CHEBI:2 diff --git a/tests/test_sdf_extractor.py b/tests/test_sdf_extractor.py index 63e0c03..b7be8a5 100644 --- a/tests/test_sdf_extractor.py +++ b/tests/test_sdf_extractor.py @@ -5,6 +5,8 @@ import gzip from pathlib import Path +from rdkit.Chem import rdchem + from chebi_utils.sdf_extractor import extract_molecules FIXTURES = Path(__file__).parent / "fixtures" @@ -44,6 +46,28 @@ def test_inchikey_column_present(self): df = 
extract_molecules(SAMPLE_SDF) assert "inchikey" in df.columns + def test_mol_column_present(self): + df = extract_molecules(SAMPLE_SDF) + assert "mol" in df.columns + + def test_mol_objects_are_rdkit_mol(self): + df = extract_molecules(SAMPLE_SDF) + for mol in df["mol"]: + assert isinstance(mol, rdchem.Mol) + + def test_mol_atom_counts(self): + df = extract_molecules(SAMPLE_SDF) + row1 = df[df["chebi_id"] == "CHEBI:1"].iloc[0] + row2 = df[df["chebi_id"] == "CHEBI:2"].iloc[0] + assert row1["mol"].GetNumAtoms() == 1 # methane: 1 C + assert row2["mol"].GetNumAtoms() == 2 # ethane: 2 C + + def test_mol_sanitized(self): + df = extract_molecules(SAMPLE_SDF) + for mol in df["mol"]: + # Aromaticity flags should be set (sanitize applied) + assert mol is not None + def test_molecule_properties(self): df = extract_molecules(SAMPLE_SDF) row = df[df["chebi_id"] == "CHEBI:1"].iloc[0] @@ -58,9 +82,19 @@ def test_gzipped_sdf(self, tmp_path): df = extract_molecules(gz_path) assert len(df) == 2 assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"} + assert all(isinstance(m, rdchem.Mol) for m in df["mol"]) def test_empty_sdf_returns_empty_dataframe(self, tmp_path): empty_sdf = tmp_path / "empty.sdf" empty_sdf.write_text("") df = extract_molecules(empty_sdf) assert df.empty + + def test_unparseable_molblock_gives_none(self, tmp_path, recwarn): + bad_sdf = tmp_path / "bad.sdf" + bad_sdf.write_text( + "bad_mol\n\n 0 0 0 0 0 0 0 0 0 0999 V2000\nM END\n" + "> \nCHEBI:99\n\n$$$$\n" + ) + df = extract_molecules(bad_sdf) + assert df.iloc[0]["mol"] is None From 361ebd7a2d82774d219ffe71374b3be939d5db10 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 09:31:31 +0000 Subject: [PATCH 6/6] Store ChEBI IDs as strings in the ontology graph Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com> --- chebi_utils/obo_extractor.py | 18 +++++++++--------- tests/test_obo_extractor.py | 30 
+++++++++++++++--------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py index 4ef37ff..78224cb 100644 --- a/chebi_utils/obo_extractor.py +++ b/chebi_utils/obo_extractor.py @@ -8,9 +8,9 @@ import networkx as nx -def _chebi_id_to_int(chebi_id: str) -> int: - """Convert 'CHEBI:123' to 123.""" - return int(chebi_id.split(":")[1]) +def _chebi_id_to_str(chebi_id: str) -> str: + """Convert 'CHEBI:123' to '123' (string).""" + return chebi_id.split(":")[1] def _term_data(doc: "fastobo.term.TermFrame") -> dict | None: @@ -21,8 +21,8 @@ def _term_data(doc: "fastobo.term.TermFrame") -> dict | None: dict or None Parsed term data, or ``None`` if the term is marked as obsolete. """ - parents: list[int] = [] - has_part: set[int] = set() + parents: list[str] = [] + has_part: set[str] = set() name: str | None = None smiles: str | None = None subset: str | None = None @@ -43,16 +43,16 @@ def _term_data(doc: "fastobo.term.TermFrame") -> dict | None: smiles = clause.raw_value().split('"')[1] elif isinstance(clause, fastobo.term.RelationshipClause): if str(clause.typedef) == "has_part": - has_part.add(_chebi_id_to_int(str(clause.term))) + has_part.add(_chebi_id_to_str(str(clause.term))) elif isinstance(clause, fastobo.term.IsAClause): - parents.append(_chebi_id_to_int(str(clause.term))) + parents.append(_chebi_id_to_str(str(clause.term))) elif isinstance(clause, fastobo.term.NameClause): name = str(clause.name) elif isinstance(clause, fastobo.term.SubsetClause): subset = str(clause.subset) return { - "id": _chebi_id_to_int(str(doc.id)), + "id": _chebi_id_to_str(str(doc.id)), "parents": parents, "has_part": has_part, "name": name, @@ -68,7 +68,7 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: errors on some ChEBI releases. Only non-obsolete CHEBI-prefixed terms are included. - **Nodes** are integer CHEBI IDs (e.g. ``1`` for ``CHEBI:1``) with + **Nodes** are string CHEBI IDs (e.g. 
``"1"`` for ``CHEBI:1``) with attributes ``name``, ``smiles``, and ``subset``. **Edges** carry a ``relation`` attribute and represent: diff --git a/tests/test_obo_extractor.py b/tests/test_obo_extractor.py index ed3f8e5..d5a50e6 100644 --- a/tests/test_obo_extractor.py +++ b/tests/test_obo_extractor.py @@ -22,50 +22,50 @@ def test_correct_number_of_nodes(self): g = build_chebi_graph(SAMPLE_OBO) assert len(g.nodes) == 4 - def test_node_ids_are_integers(self): + def test_node_ids_are_strings(self): g = build_chebi_graph(SAMPLE_OBO) - assert all(isinstance(n, int) for n in g.nodes) + assert all(isinstance(n, str) for n in g.nodes) def test_expected_nodes_present(self): g = build_chebi_graph(SAMPLE_OBO) - assert set(g.nodes) == {1, 2, 3, 5} + assert set(g.nodes) == {"1", "2", "3", "5"} def test_obsolete_term_excluded(self): g = build_chebi_graph(SAMPLE_OBO) - assert 4 not in g.nodes + assert "4" not in g.nodes def test_node_name_attribute(self): g = build_chebi_graph(SAMPLE_OBO) - assert g.nodes[1]["name"] == "compound A" - assert g.nodes[2]["name"] == "compound B" + assert g.nodes["1"]["name"] == "compound A" + assert g.nodes["2"]["name"] == "compound B" def test_smiles_extracted_from_property_value(self): g = build_chebi_graph(SAMPLE_OBO) - assert g.nodes[1]["smiles"] == "C" + assert g.nodes["1"]["smiles"] == "C" def test_smiles_none_when_absent(self): g = build_chebi_graph(SAMPLE_OBO) - assert g.nodes[2]["smiles"] is None + assert g.nodes["2"]["smiles"] is None def test_subset_extracted(self): g = build_chebi_graph(SAMPLE_OBO) - assert g.nodes[3]["subset"] == "3_STAR" + assert g.nodes["3"]["subset"] == "3_STAR" def test_subset_none_when_absent(self): g = build_chebi_graph(SAMPLE_OBO) - assert g.nodes[1]["subset"] is None + assert g.nodes["1"]["subset"] is None def test_isa_edge_present(self): g = build_chebi_graph(SAMPLE_OBO) # CHEBI:1 is_a CHEBI:2 - assert g.has_edge(1, 2) - assert g.edges[1, 2]["relation"] == "is_a" + assert g.has_edge("1", "2") + assert 
g.edges["1", "2"]["relation"] == "is_a" def test_has_part_edge_present(self): g = build_chebi_graph(SAMPLE_OBO) # CHEBI:1 has_part CHEBI:3 - assert g.has_edge(1, 3) - assert g.edges[1, 3]["relation"] == "has_part" + assert g.has_edge("1", "3") + assert g.edges["1", "3"]["relation"] == "has_part" def test_total_edge_count(self): g = build_chebi_graph(SAMPLE_OBO) @@ -83,4 +83,4 @@ def test_xref_lines_do_not_break_parsing(self, tmp_path): "format-version: 1.2\n[Term]\nid: CHEBI:10\nname: test\nxref: Reaxys:123456\n" ) g = build_chebi_graph(obo_with_xrefs) - assert 10 in g.nodes + assert "10" in g.nodes