diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index e295a3ed..c62c82f9 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -940,7 +940,7 @@ def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. - Uses the graph created by `_extract_class_hierarchy` method to extract the + Uses the graph created by chebi_utils to extract the raw data in Dataframe format with additional columns corresponding to each multi-label class. Args: @@ -951,21 +951,6 @@ def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame: """ pass - @abstractmethod - def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: - """ - Selects classes from the dataset based on a specified criteria. - - Args: - g (nx.Graph): The graph representing the dataset. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - List: A sorted list of node IDs that meet the specified criteria. - """ - pass - def save_processed(self, data: pd.DataFrame, filename: str) -> None: """ Save the processed dataset to a pickle file. @@ -1123,120 +1108,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ pass - def get_test_split( - self, df: pd.DataFrame, seed: Optional[int] = None - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """ - Split the input DataFrame into training and testing sets based on multilabel stratified sampling. - - This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels - in the training and testing sets is approximately the same. The split is based on the "labels" column - in the DataFrame. - - Args: - df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column - named "labels" with the multilabel data. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. - - Raises: - ValueError: If the DataFrame does not contain a column named "labels". - """ - from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit - from sklearn.model_selection import StratifiedShuffleSplit - - print("Get test data split") - - labels_list = df["labels"].tolist() - - if len(labels_list[0]) > 1: - splitter = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=self.test_split, random_state=seed - ) - else: - splitter = StratifiedShuffleSplit( - n_splits=1, test_size=self.test_split, random_state=seed - ) - - train_indices, test_indices = next(splitter.split(labels_list, labels_list)) - - df_train = df.iloc[train_indices] - df_test = df.iloc[test_indices] - return df_train, df_test - - def get_train_val_splits_given_test( - self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None - ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: - """ - Split the dataset into train and validation sets, given a test set. - Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap - - Args: - df (pd.DataFrame): The original dataset. - test_df (pd.DataFrame): The test dataset. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and - validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train - and validation DataFrames. The keys are the names of the train and validation sets, and the values - are the corresponding DataFrames. - """ - from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, - ) - from sklearn.model_selection import StratifiedShuffleSplit - - print("Split dataset into train / val with given test set") - - test_ids = test_df["ident"].tolist() - df_trainval = df[~df["ident"].isin(test_ids)] - labels_list_trainval = df_trainval["labels"].tolist() - - if self.use_inner_cross_validation: - folds = {} - kfold = MultilabelStratifiedKFold( - n_splits=self.inner_k_folds, random_state=seed - ) - for fold, (train_ids, val_ids) in enumerate( - kfold.split( - labels_list_trainval, - labels_list_trainval, - ) - ): - df_validation = df_trainval.iloc[val_ids] - df_train = df_trainval.iloc[train_ids] - folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train - folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( - df_validation - ) - - return folds - - if len(labels_list_trainval[0]) > 1: - splitter = MultilabelStratifiedShuffleSplit( - n_splits=1, - test_size=self.validation_split / (1 - self.test_split), - random_state=seed, - ) - else: - splitter = StratifiedShuffleSplit( - n_splits=1, - test_size=self.validation_split / (1 - self.test_split), - random_state=seed, - ) - - train_indices, validation_indices = next( - splitter.split(labels_list_trainval, labels_list_trainval) - ) - - df_validation = df_trainval.iloc[validation_indices] - df_train = df_trainval.iloc[train_indices] - return df_train, df_validation - def _retrieve_splits_from_csv(self) -> None: """ Retrieve previously saved data splits from splits.csv file or from provided file path. diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index d8296530..b22f90b3 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -6,13 +6,11 @@ import os import random from abc import ABC -from collections import OrderedDict from itertools import cycle, permutations, product -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Generator, List, Literal, Optional import numpy as np import pandas as pd -import torch from rdkit import Chem from tqdm import tqdm @@ -20,24 +18,8 @@ from chebai.preprocessing.datasets.base import _DynamicDataset if TYPE_CHECKING: - import fastobo import networkx as nx -# exclude some entities from the dataset because the violate disjointness axioms -CHEBI_BLACKLIST = [ - 194026, - 144321, - 156504, - 167175, - 167174, - 167178, - 183506, - 74635, - 3311, - 190439, - 92386, -] - class _ChEBIDataExtractor(_DynamicDataset, ABC): """ @@ -60,13 +42,12 @@ class _ChEBIDataExtractor(_DynamicDataset, ABC): # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------ # "id" at row index 0 - # "name" at row index 1 - # "SMILES" at row index 2 - # "mol" at row index 3 - # labels starting from row index 4 + # "mol" at row index 1 + # labels starting from row index 2 _ID_IDX: int = 0 - _DATA_REPRESENTATION_IDX: int = 3 - _LABELS_START_IDX: int = 4 + _DATA_REPRESENTATION_IDX: int = 1 + _LABELS_START_IDX: int = 2 + THRESHOLD: int = None def __init__( self, @@ -175,7 +156,8 @@ def _load_chebi(self, version: Optional[int] = None) -> str: Returns: str: The file path of the loaded ChEBI ontology. """ - import requests + if version is None: + version = self.chebi_version if version is None: version = self.chebi_version @@ -186,12 +168,9 @@ def _load_chebi(self, version: Optional[int] = None) -> str: print( f"Missing raw ChEBI data related for version v{version}. Downloading..." ) - if version < 245: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" - else: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/ontology/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) + from chebi_utils import download_chebi_obo + + download_chebi_obo(version, dest_dir=self.raw_dir, filename=chebi_name) return chebi_path def _load_sdf(self, version: Optional[int] = None) -> str: @@ -204,10 +183,6 @@ def _load_sdf(self, version: Optional[int] = None) -> str: Returns: str: The file path of the loaded ChEBI SDF file. """ - import requests - import gzip - import shutil - if version is None: version = self.chebi_version @@ -215,15 +190,9 @@ def _load_sdf(self, version: Optional[int] = None) -> str: sdf_path = os.path.join(self.raw_dir, sdf_name) if not os.path.isfile(sdf_path): print(f"Missing raw SDF data related to version v{version}. Downloading...") - if version < 245: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" - else: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/SDF/chebi.sdf.gz" - r = requests.get(url, allow_redirects=True, stream=True) - open(sdf_path + ".gz", "wb").write(r.content) - with gzip.open(sdf_path + ".gz", "rb") as f_in: - with open(sdf_path, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) + from chebi_utils import download_chebi_sdf + + download_chebi_sdf(version, dest_dir=self.raw_dir, filename=sdf_name) return sdf_path def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": @@ -232,48 +201,28 @@ def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from the chebi term documents from `.obo` file. + Uses :func:`chebi_utils.build_chebi_graph` for OBO parsing and graph construction. + Args: data_path (str): The path to the ChEBI ontology. Returns: - nx.DiGraph: The class hierarchy. + nx.DiGraph: The class hierarchy (transitive closure, edges directed parent → child). """ - import fastobo - import networkx as nx + from chebi_utils import build_chebi_graph, get_hierarchy_subgraph - with open(data_path, encoding="utf-8") as chebi: - chebi = "\n".join(line for line in chebi if not line.startswith("xref:")) + full_graph = get_hierarchy_subgraph(build_chebi_graph(data_path)) - elements = [] - for term_doc in fastobo.loads(chebi): - if ( - term_doc - and isinstance(term_doc.id, fastobo.id.PrefixedIdent) - and term_doc.id.prefix == "CHEBI" - ): - term_dict = term_callback(term_doc) - if term_dict and ( - not self.subset - or ( - "subset" in term_dict - and term_dict["subset"] is not None - and term_dict["subset"][0] == self.subset[0] - ) # match 3:STAR to 3_STAR, 3star, 3_star, etc. - ): - elements.append(term_dict) - - g = nx.DiGraph() - for n in elements: - g.add_node(n["id"], **n) - - # Only take the edges which connects the existing nodes, to avoid internal creation of obsolete nodes - # https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142 - g.add_edges_from( - [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)] - ) - - print("Compute transitive closure") - return nx.transitive_closure_dag(g) + # Filter by subset if specified + if self.subset: + nodes_to_keep = [ + n + for n, d in full_graph.nodes(data=True) + if d.get("subset") is not None and d["subset"][0] == self.subset[0] + # match 3:STAR to 3_STAR, 3star, 3_star, etc. + ] + full_graph = full_graph.subgraph(nodes_to_keep).copy() + return full_graph def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: """ @@ -281,66 +230,25 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: Uses the graph created by `_extract_class_hierarchy` method to extract the raw data in Dataframe format with additional columns corresponding to each multi-label class. + Uses :func:`chebi_utils.sdf_extractor.extract_molecules` for SDF parsing. + Args: g (nx.DiGraph): The class hierarchy graph. Returns: pd.DataFrame: The raw dataset created from the graph. """ - import networkx as nx - smiles = nx.get_node_attributes(g, "smiles") - names = nx.get_node_attributes(g, "name") - - print(f"Processing {g}") + # Extract mol objects from SDF using chebi-utils + from chebi_utils import build_labeled_dataset, extract_molecules - molecules, smiles_list = zip( - *( - (n, smiles) - for n, smiles in ((n, smiles.get(n)) for n in smiles.keys()) - if smiles - ) - ) - data = OrderedDict(id=molecules) # `id` column at index 0 - data["name"] = [ - names.get(node) for node in molecules - ] # `name` column at index 1 - data["SMILES"] = smiles_list # `SMILES` (data representation) column at index 2 - - # # `mol` (RDKit Mol object) column at index 3 - from chembl_structure_pipeline.standardizer import ( - parse_molblock, - ) + sdf_path = os.path.join(self.raw_dir, self.raw_file_names_dict["sdf"]) + mol_df = extract_molecules(sdf_path) + data, labels = build_labeled_dataset(g, mol_df, self.THRESHOLD) - with open( - os.path.join(self.raw_dir, self.raw_file_names_dict["sdf"]), "rb" - ) as f: - # split input into blocks separated by "$$$$" - blocks = f.read().decode("utf-8").split("$$$$\n") - id_to_mol = dict() - for molfile in tqdm(blocks, desc="Processing SDF molecules"): - if "" not in molfile: - print(f"Skipping molfile without ChEBI ID: {molfile[:30]}...") - continue - ident = int(molfile.split("")[1].split(">")[0].split("CHEBI:")[1]) - # use same parsing strategy as CHEBI: github.com/chembl/libRDChEBI/blob/main/libRDChEBI/formats.py - mol = parse_molblock(molfile) - if mol is None: - print(f"Failed to parse molfile for CHEBI:{ident}") - continue - mol = sanitize_molecule(mol) - id_to_mol[ident] = mol - data["mol"] = [id_to_mol.get(node) for node in molecules] - - # Labels columns from index 4 onwards - for n in self.select_classes(g): - data[n] = [ - ((n in g.predecessors(node)) or (n == node)) for node in molecules - ] - - data = pd.DataFrame(data) - data = data[~data["mol"].isnull()] - data = data[~data["name"].isin(CHEBI_BLACKLIST)] + filename = "classes.txt" + with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: + fout.writelines(str(label) + "\n" for label in labels) return data @@ -491,85 +399,26 @@ def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, No for feat, labels, ident in zip(features, all_labels, idents): yield dict(features=feat, labels=labels, ident=ident) - # ------------------------------ Phase: Dynamic Splits ----------------------------------- def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ Loads encoded/transformed data and generates training, validation, and test splits. - - This method first loads encoded data from a file named `data.pt`, which is derived from either - `chebi_version` or `chebi_version_train`. It then splits the data into training, validation, and test sets. - - If `chebi_version_train` is provided: - - Loads additional encoded data from `chebi_version_train`. - - Splits this data into training and validation sets, while using the test set from `chebi_version`. - - Prunes the test set from `chebi_version` to include only labels that exist in `chebi_version_train`. - - If `chebi_version_train` is not provided: - - Splits the data from `chebi_version` into training, validation, and test sets without modification. - - Raises: - FileNotFoundError: If the required `data.pt` file(s) do not exist. Ensure that `prepare_data` - and/or `setup` methods have been called to generate the dataset files. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: - - Training set - - Validation set - - Test set """ - try: - filename = self.processed_file_names_dict["data"] - data_chebi_version = self.load_processed_data_from_file(filename) - except FileNotFoundError: - raise FileNotFoundError( - "File data.pt doesn't exists. " - "Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" - ) - df_chebi_version = pd.DataFrame(data_chebi_version) - train_df_chebi_ver, df_test_chebi_ver = self.get_test_split( - df_chebi_version, seed=self.dynamic_data_split_seed - ) + filename = self.processed_file_names_dict["data"] + data = self.load_processed_data_from_file(filename) + df_data = pd.DataFrame(data) - if self.chebi_version_train is not None: - # Load encoded data derived from "chebi_version_train" - try: - filename_train = ( - self._chebi_version_train_obj.processed_file_names_dict["data"] - ) - data_chebi_train_version = torch.load( - os.path.join( - self._chebi_version_train_obj.processed_dir, filename_train - ), - weights_only=False, - ) - except FileNotFoundError: - raise FileNotFoundError( - f"File data.pt doesn't exists related to chebi_version_train {self.chebi_version_train}." - f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" - ) + from chebi_utils import create_multilabel_splits - df_chebi_train_version = pd.DataFrame(data_chebi_train_version) - # Get train/val split of data based on "chebi_version_train", but - # using test set from "chebi_version" - df_train, df_val = self.get_train_val_splits_given_test( - df_chebi_train_version, - df_test_chebi_ver, - seed=self.dynamic_data_split_seed, - ) - # Modify test set from "chebi_version" to only include the labels that - # exists in "chebi_version_train", all other entries remains same. - df_test = self._setup_pruned_test_set(df_test_chebi_ver) - else: - # Get all splits based on "chebi_version" - df_train, df_val = self.get_train_val_splits_given_test( - train_df_chebi_ver, - df_test_chebi_ver, - seed=self.dynamic_data_split_seed, - ) - df_test = df_test_chebi_ver - - return df_train, df_val, df_test + splits = create_multilabel_splits( + df_data, + self._LABELS_START_IDX, + 1 - self.validation_split - self.test_split, + self.validation_split, + self.test_split, + self.dynamic_data_split_seed, + ) + return splits["train"], splits["val"], splits["test"] def _setup_pruned_test_set( self, df_test_chebi_version: pd.DataFrame @@ -666,7 +515,7 @@ def processed_dir(self) -> str: @property def raw_file_names_dict(self) -> dict: - return {"chebi": "chebi.obo", "sdf": "chebi.sdf"} + return {"chebi": "chebi.obo", "sdf": "chebi.sdf.gz"} @property def processed_main_file_names_dict(self) -> dict: @@ -714,7 +563,6 @@ class ChEBIOverX(_ChEBIDataExtractor): """ READER: dr.ChemDataReader = dr.ChemDataReader - THRESHOLD: int = None @property def _name(self) -> str: @@ -726,51 +574,6 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: - """ - Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. - - This method iterates over the nodes in the graph, counting the number of successors for each node. - Nodes with a number of successors greater than or equal to the defined threshold are selected. - - Note: - The input graph must be transitive closure of a directed acyclic graph. - - Args: - g (nx.Graph): The graph representing the dataset. - *args: Additional positional arguments (not used). - **kwargs: Additional keyword arguments (not used). - - Returns: - List: A sorted list of node IDs that meet the successor threshold criteria. - - Side Effects: - Writes the list of selected nodes to a file named "classes.txt" in the specified processed directory. - - Notes: - - The `THRESHOLD` attribute should be defined in the subclass of this class. - - Nodes without a 'smiles' attribute are ignored in the successor count. - """ - import networkx as nx - - smiles = nx.get_node_attributes(g, "smiles") - nodes = list( - sorted( - { - node - for node in g.nodes - if sum( - 1 if smiles[s] is not None else 0 for s in g.successors(node) - ) - >= self.THRESHOLD - } - ) - ) - filename = "classes.txt" - with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: - fout.writelines(str(node) + "\n" for node in nodes) - return nodes - class ChEBIOverXDeepSMILES(ChEBIOverX): """ @@ -983,102 +786,6 @@ class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100): pass -def chebi_to_int(s: str) -> int: - """ - Converts a ChEBI term string representation to an integer ID. - - Args: - - s (str): A ChEBI term string, e.g., "CHEBI:12345". - - Returns: - - int: The integer ID extracted from the ChEBI term string. - """ - return int(s[s.index(":") + 1 :]) - - -def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]: - """ - Extracts information from a ChEBI term document. - This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents, - parts, name, and SMILES string. It returns a dictionary containing the extracted information. - - Args: - - doc: A ChEBI term document. - - Returns: - A dictionary containing the following keys: - - "id": The ID of the ChEBI term. - - "parents": A list of parent term IDs. - - "has_part": A set of term IDs representing the parts of the ChEBI term. - - "name": The name of the ChEBI term. - - "smiles": The SMILES string associated with the ChEBI term, if available. - """ - import fastobo - - parts = set() - parents = [] - name = None - smiles = None - subset = None - for clause in doc: - if isinstance(clause, fastobo.term.PropertyValueClause): - t = clause.property_value - # chemrof:smiles_string is the new annotation property, chebi/smiles is the old one (see https://chembl.blogspot.com/2025/07/chebi-20-data-products.html) - if ( - str(t.relation) == "chemrof:smiles_string" - or str(t.relation) == "http://purl.obolibrary.org/obo/chebi/smiles" - ): - assert smiles is None - smiles = t.value - # in older chebi versions, smiles strings are synonyms - # e.g. synonym: "[F-].[Na+]" RELATED SMILES [ChEBI] - elif isinstance(clause, fastobo.term.SynonymClause): - if "SMILES" in clause.raw_value(): - assert smiles is None - smiles = clause.raw_value().split('"')[1] - elif isinstance(clause, fastobo.term.RelationshipClause): - if str(clause.typedef) == "has_part": - parts.add(chebi_to_int(str(clause.term))) - elif isinstance(clause, fastobo.term.IsAClause): - parents.append(chebi_to_int(str(clause.term))) - elif isinstance(clause, fastobo.term.NameClause): - name = str(clause.name) - elif isinstance(clause, fastobo.term.SubsetClause): - subset = str(clause.subset) - - if isinstance(clause, fastobo.term.IsObsoleteClause): - if clause.obsolete: - # if the term document contains clause as obsolete as true, skips this document. - return False - - return { - "id": chebi_to_int(str(doc.id)), - "parents": parents, - "has_part": parts, - "name": name, - "smiles": smiles, - "subset": subset, - } - - -def sanitize_molecule(mol: Chem.Mol) -> Chem.Mol: - # mirror ChEBI molecule processing - from chembl_structure_pipeline.standardizer import update_mol_valences - - mol = update_mol_valences(mol) - Chem.SanitizeMol( - mol, - sanitizeOps=Chem.SanitizeFlags.SANITIZE_FINDRADICALS - | Chem.SanitizeFlags.SANITIZE_KEKULIZE - | Chem.SanitizeFlags.SANITIZE_SETAROMATICITY - | Chem.SanitizeFlags.SANITIZE_SETCONJUGATION - | Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION - | Chem.SanitizeFlags.SANITIZE_SYMMRINGS, - catchErrors=True, - ) - return mol - - if __name__ == "__main__": dataset = ChEBIOver50(chebi_version=248, subset="3_STAR") dataset.prepare_data() diff --git a/pyproject.toml b/pyproject.toml index 715f4555..a253c1f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,13 +35,12 @@ dev = [ "requests", "scikit-learn", "scipy", - "fastobo", "selfies", "jsonargparse[signatures]>=4.17", "omegaconf", "deepsmiles", - "iterative-stratification", "torchmetrics", + "chebi-utils", ] linters = [