From 55585b929af3d5774856dfca647fc2ae220e561e Mon Sep 17 00:00:00 2001 From: chasem Date: Tue, 10 Mar 2026 08:27:30 -0500 Subject: [PATCH] updates necessary for dto rr analysis --- docs/tutorials/virtual_db_tutorial.ipynb | 10 +- docs/virtual_db.md | 14 +- docs/virtual_db_configuration.md | 3 +- tfbpapi/datacard.py | 2 +- tfbpapi/models.py | 42 +- tfbpapi/models_deprecated.py | 732 ----------------------- tfbpapi/tests/test_models.py | 93 +++ tfbpapi/tests/test_virtual_db.py | 215 +++++-- tfbpapi/virtual_db.py | 405 ++++++++----- 9 files changed, 570 insertions(+), 946 deletions(-) delete mode 100644 tfbpapi/models_deprecated.py diff --git a/docs/tutorials/virtual_db_tutorial.ipynb b/docs/tutorials/virtual_db_tutorial.ipynb index c1aecab..ffc73b9 100644 --- a/docs/tutorials/virtual_db_tutorial.ipynb +++ b/docs/tutorials/virtual_db_tutorial.ipynb @@ -147,13 +147,7 @@ "cell_type": "markdown", "id": "cell-3", "metadata": {}, - "source": [ - "## Initializing VirtualDB\n", - "\n", - "Creating a VirtualDB instance loads and validates the config but does\n", - "**not** download any data yet. Views are registered lazily on the first\n", - "`query()`, `tables()`, or `describe()` call." - ] + "source": "## Initializing VirtualDB\n\nCreating a VirtualDB instance loads and validates the config, downloads any\nnecessary data, and registers all views immediately." }, { "cell_type": "code", @@ -5100,4 +5094,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/docs/virtual_db.md b/docs/virtual_db.md index 9062618..ded59f4 100644 --- a/docs/virtual_db.md +++ b/docs/virtual_db.md @@ -25,27 +25,21 @@ and the [tutorial](tutorials/virtual_db_tutorial.ipynb) for usage examples. ## Advanced Usage -After any public method is called (e.g. `vdb.tables()`), the underlying DuckDB -connection is available as `vdb._db`. You can use `_db` to execute any SQL -on the database, eg creating more views, or creating a table in memory +The underlying DuckDB connection is available as `vdb._conn`. You can use +`_conn` to execute any SQL on the database, eg creating more views, or +creating a table in memory. Custom **views** created this way appear in `tables()`, `describe()`, and `get_fields()` automatically because those methods query DuckDB's `information_schema`. Custom **tables** do not appear in `tables()` (which only lists views), but are fully queryable via `vdb.query()`. -Call at least one public method first to ensure the connection is initialized -before accessing `_db` directly. - Example -- create a materialized analysis table:: - # Trigger view registration - vdb.tables() - # Create a persistent in-memory table from a complex query. # This example selects one "best" Hackett-2020 sample per regulator # using a priority system: ZEV+P > GEV+P > GEV+M. - vdb._db.execute(""" + vdb._conn.execute(""" CREATE OR REPLACE TABLE hackett_analysis_set AS WITH regulator_tiers AS ( SELECT diff --git a/docs/virtual_db_configuration.md b/docs/virtual_db_configuration.md index 42316d5..06897c3 100644 --- a/docs/virtual_db_configuration.md +++ b/docs/virtual_db_configuration.md @@ -255,8 +255,7 @@ for more detailed explanation of comparative datasets and composite IDs. ## Internal Structure VirtualDB uses an in-memory DuckDB database to construct a layered hierarchy -of SQL views over locally cached Parquet files. Views are created lazily on -first query and are not persisted to disk. +of SQL views over locally cached Parquet files. Views are created on initialization and are not persisted to disk. ### View Hierarchy diff --git a/tfbpapi/datacard.py b/tfbpapi/datacard.py index 734a5f3..52a573c 100644 --- a/tfbpapi/datacard.py +++ b/tfbpapi/datacard.py @@ -307,7 +307,7 @@ def _build_metadata_fields_map(self) -> None: ] break else: - self.logger.warning( + self.logger.info( "No metadata fields found for data config '%s' " "in repo '%s' -- no embedded metadata_fields and " "no metadata config with applies_to", diff --git a/tfbpapi/models.py b/tfbpapi/models.py index 4d77f02..c2c9501 100644 --- a/tfbpapi/models.py +++ b/tfbpapi/models.py @@ -9,6 +9,7 @@ """ +import logging from enum import Enum from functools import cached_property from pathlib import Path @@ -29,6 +30,9 @@ FactorAliases: TypeAlias = dict[str, dict[str, list[str | int | float | bool]]] +logger = logging.getLogger(__name__) + + class DatasetType(str, Enum): """Supported dataset types.""" @@ -761,6 +765,23 @@ def validate_factor_aliases(cls, v: FactorAliases) -> FactorAliases: ) return v + @model_validator(mode="after") + def validate_repositories_have_datasets(self) -> "MetadataConfig": + """ + Validate that every repository defines at least one dataset. + + :return: The validated MetadataConfig instance + :raises ValueError: If any repository has no datasets defined + + """ + for repo_id, repo_config in self.repositories.items(): + if not repo_config.dataset: + raise ValueError( + f"Repository '{repo_id}' must define at least one " + "dataset under the 'dataset' key." + ) + return self + @model_validator(mode="after") def validate_unique_db_names(self) -> "MetadataConfig": """ @@ -791,13 +812,19 @@ def validate_unique_db_names(self) -> "MetadataConfig": @model_validator(mode="before") @classmethod - def parse_repositories(cls, data: Any) -> dict[str, Any]: + def parse_config(cls, data: Any) -> dict[str, Any]: """ - Parse repository configurations from 'repositories' key. + Parse and validate all top-level sections of the VirtualDB configuration. + + Handles the four top-level sections: ``repositories`` (required), + ``factor_aliases``, ``missing_value_labels``, and ``description`` + (all optional). Logs an INFO message for each optional section that + is absent from the configuration. :param data: Raw configuration data - :return: Processed configuration with parsed repositories - :raises ValueError: If repositories are invalid or missing + :return: Processed configuration dict ready for Pydantic field validation + :raises ValueError: If ``repositories`` is missing or empty, or if + any repository config is invalid """ if not isinstance(data, dict): @@ -811,6 +838,13 @@ def parse_repositories(cls, data: Any) -> dict[str, Any]: "with at least one repository" ) + for optional_key in ("factor_aliases", "missing_value_labels", "description"): + if not data.get(optional_key): + logger.info( + "No '%s' section found in VirtualDB configuration.", + optional_key, + ) + # Parse each repository config repositories = {} for repo_id, repo_config in repositories_data.items(): diff --git a/tfbpapi/models_deprecated.py b/tfbpapi/models_deprecated.py deleted file mode 100644 index 6888579..0000000 --- a/tfbpapi/models_deprecated.py +++ /dev/null @@ -1,732 +0,0 @@ -""" -Pydantic models for dataset card validation and metadata configuration. - -These models provide minimal structure for parsing HuggingFace dataset cards while -remaining flexible enough to accommodate diverse experimental systems. Most fields use -extra="allow" to accept domain-specific additions without requiring code changes. - -Also includes models for VirtualDB metadata normalization configuration. - -""" - -from enum import Enum -from pathlib import Path -from typing import Any - -import yaml # type: ignore[import-untyped] -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - - -class DatasetType(str, Enum): - """Supported dataset types.""" - - GENOMIC_FEATURES = "genomic_features" - ANNOTATED_FEATURES = "annotated_features" - GENOME_MAP = "genome_map" - METADATA = "metadata" - COMPARATIVE = "comparative" - - -class FeatureInfo(BaseModel): - """ - Information about a dataset feature/column. - - Minimal required fields with flexible dtype handling. - - """ - - name: str = Field(..., description="Column name in the data") - dtype: str | dict[str, Any] = Field( - ..., - description="Data type (string, int64, float64, etc.) or class_label dict", - ) - description: str = Field(..., description="Description of the field") - role: str | None = Field( - default=None, - description="Optional semantic role. 'experimental_condition' " - "has special behavior.", - ) - definitions: dict[str, Any] | None = Field( - default=None, - description="For experimental_condition fields: definitions per value", - ) - - -class PartitioningInfo(BaseModel): - """Partitioning configuration for datasets.""" - - enabled: bool = Field(default=False, description="Whether partitioning is enabled") - partition_by: list[str] | None = Field( - default=None, description="Partition column names" - ) - path_template: str | None = Field( - default=None, description="Path template for partitioned files" - ) - - -class DatasetInfo(BaseModel): - """Dataset structure information.""" - - features: list[FeatureInfo] = Field(..., description="Feature definitions") - partitioning: PartitioningInfo | None = Field( - default=None, description="Partitioning configuration" - ) - - -class DataFileInfo(BaseModel): - """Information about data files.""" - - split: str = Field(default="train", description="Dataset split name") - path: str = Field(..., description="Path to data file(s)") - - -class DatasetConfig(BaseModel): - """ - Configuration for a dataset within a repository. - - Uses extra="allow" to accept arbitrary experimental_conditions and other fields. - - """ - - config_name: str = Field(..., description="Unique configuration identifier") - description: str = Field(..., description="Human-readable description") - dataset_type: DatasetType = Field(..., description="Type of dataset") - default: bool = Field( - default=False, description="Whether this is the default config" - ) - applies_to: list[str] | None = Field( - default=None, description="Configs this metadata applies to" - ) - metadata_fields: list[str] | None = Field( - default=None, description="Fields for embedded metadata extraction" - ) - data_files: list[DataFileInfo] = Field(..., description="Data file information") - dataset_info: DatasetInfo = Field(..., description="Dataset structure information") - - model_config = ConfigDict(extra="allow") - - @field_validator("applies_to") - @classmethod - def applies_to_only_for_metadata(cls, v, info): - """Validate that applies_to is only used for metadata or comparative configs.""" - if v is not None: - dataset_type = info.data.get("dataset_type") - if dataset_type not in (DatasetType.METADATA, DatasetType.COMPARATIVE): - raise ValueError( - "applies_to field is only valid " - "for metadata and comparative dataset types" - ) - return v - - @field_validator("metadata_fields") - @classmethod - def metadata_fields_validation(cls, v): - """Validate metadata_fields usage.""" - if v is not None and len(v) == 0: - raise ValueError("metadata_fields cannot be empty list, use None instead") - return v - - -class DatasetCard(BaseModel): - """ - Complete dataset card model. - - Uses extra="allow" to accept arbitrary top-level metadata and - experimental_conditions. - - """ - - configs: list[DatasetConfig] = Field(..., description="Dataset configurations") - - model_config = ConfigDict(extra="allow") - - @field_validator("configs") - @classmethod - def configs_not_empty(cls, v): - """Ensure at least one config is present.""" - if not v: - raise ValueError("At least one dataset configuration is required") - return v - - @field_validator("configs") - @classmethod - def unique_config_names(cls, v): - """Ensure config names are unique.""" - names = [config.config_name for config in v] - if len(names) != len(set(names)): - raise ValueError("Configuration names must be unique") - return v - - @field_validator("configs") - @classmethod - def at_most_one_default(cls, v): - """Ensure at most one config is marked as default.""" - defaults = [config for config in v if config.default] - if len(defaults) > 1: - raise ValueError("At most one configuration can be marked as default") - return v - - def get_config_by_name(self, name: str) -> DatasetConfig | None: - """Get a configuration by name.""" - for config in self.configs: - if config.config_name == name: - return config - return None - - def get_configs_by_type(self, dataset_type: DatasetType) -> list[DatasetConfig]: - """Get all configurations of a specific type.""" - return [ - config for config in self.configs if config.dataset_type == dataset_type - ] - - def get_default_config(self) -> DatasetConfig | None: - """Get the default configuration if one exists.""" - defaults = [config for config in self.configs if config.default] - return defaults[0] if defaults else None - - def get_data_configs(self) -> list[DatasetConfig]: - """Get all non-metadata configurations.""" - return [ - config - for config in self.configs - if config.dataset_type != DatasetType.METADATA - ] - - def get_metadata_configs(self) -> list[DatasetConfig]: - """Get all metadata configurations.""" - return [ - config - for config in self.configs - if config.dataset_type == DatasetType.METADATA - ] - - -class ExtractedMetadata(BaseModel): - """Metadata extracted from datasets.""" - - config_name: str = Field(..., description="Source configuration name") - field_name: str = Field( - ..., description="Field name the metadata was extracted from" - ) - values: set[str] = Field(..., description="Unique values found") - extraction_method: str = Field(..., description="How the metadata was extracted") - - model_config = ConfigDict( - # Allow sets in JSON serialization - json_encoders={set: list} - ) - - -class MetadataRelationship(BaseModel): - """Relationship between a data config and its metadata.""" - - data_config: str = Field(..., description="Data configuration name") - metadata_config: str = Field(..., description="Metadata configuration name") - relationship_type: str = Field( - ..., description="Type of relationship (explicit, embedded)" - ) - - -# ============================================================================ -# VirtualDB Metadata Configuration Models -# ============================================================================ - - -class ComparativeAnalysis(BaseModel): - """ - Reference to a comparative dataset that includes this dataset. - - Comparative datasets relate samples across multiple source datasets. - This model specifies which comparative dataset references the current - dataset and through which field (via_field). - - Attributes: - repo: HuggingFace repository ID of the comparative dataset - dataset: Config name of the comparative dataset - via_field: Field in the comparative dataset containing composite - identifiers that reference this dataset's samples. - Format: "repo_id;config_name;sample_id" - - Example: - ```python - # In BrentLab/callingcards config - ComparativeAnalysis( - repo="BrentLab/yeast_comparative_analysis", - dataset="dto", - via_field="binding_id" - ) - # Means: dto dataset has a binding_id field with values like: - # "BrentLab/callingcards;annotated_features;123" - ``` - - """ - - repo: str = Field(..., description="Comparative dataset repository ID") - dataset: str = Field(..., description="Comparative dataset config name") - via_field: str = Field( - ..., description="Field containing composite sample identifiers" - ) - - -class PropertyMapping(BaseModel): - """ - Mapping specification for a single property. - - Attributes: - path: Optional dot-notation path to the property value. - For repo/config-level: relative to experimental_conditions - For field-level: relative to field definitions - When omitted with field specified, creates a column alias. - field: Optional field name for field-level properties. - When specified, looks in this field's definitions. - When omitted, looks in repo/config-level experimental_conditions. - expression: Optional SQL expression for derived/computed fields. - When specified, creates a computed column. - Cannot be used with field or path. - dtype: Optional data type specification for type conversion. - Supported values: 'string', 'numeric', 'bool'. - When specified, extracted values are converted to this type. - - Examples: - Field-level property with path: - PropertyMapping(field="condition", path="media.carbon_source") - - Repo/config-level property: - PropertyMapping(path="temperature_celsius") - - Field-level column alias (no path): - PropertyMapping(field="condition") - - Derived field with expression: - PropertyMapping(expression="dto_fdr < 0.05") - - """ - - field: str | None = Field(None, description="Field name for field-level properties") - path: str | None = Field(None, description="Dot-notation path to property") - expression: str | None = Field( - None, description="SQL expression for derived fields" - ) - dtype: str | None = Field( - None, description="Data type for conversion: 'string', 'numeric', or 'bool'" - ) - - @field_validator("path") - @classmethod - def validate_path(cls, v: str | None) -> str | None: - """Ensure path is not just whitespace if provided.""" - if v is not None and not v.strip(): - raise ValueError("path cannot be empty or whitespace") - return v.strip() if v else None - - @field_validator("field") - @classmethod - def validate_field(cls, v: str | None) -> str | None: - """Ensure field is not empty string if provided.""" - if v is not None and not v.strip(): - raise ValueError("field cannot be empty or whitespace") - return v.strip() if v else None - - @field_validator("expression") - @classmethod - def validate_expression(cls, v: str | None) -> str | None: - """Ensure expression is not empty string if provided.""" - if v is not None and not v.strip(): - raise ValueError("expression cannot be empty or whitespace") - return v.strip() if v else None - - @model_validator(mode="after") - def validate_at_least_one_specified(self) -> "PropertyMapping": - """Ensure at least one field type is specified and mutually exclusive.""" - if self.expression is not None: - if self.field is not None or self.path is not None: - raise ValueError( - "expression cannot be used with field or path - " - "derived fields are computed, not extracted" - ) - elif self.field is None and self.path is None: - raise ValueError( - "At least one of 'field', 'path', or 'expression' must be specified" - ) - return self - - -class DatasetVirtualDBConfig(BaseModel): - """ - VirtualDB configuration for a specific dataset within a repository. - - Attributes: - sample_id: Mapping for the sample identifier field (required for - primary datasets) - comparative_analyses: Optional list of comparative datasets that - reference this dataset - properties: Property mappings for this specific dataset (field names to - PropertyMapping) - - Example: - ```yaml - # In BrentLab/callingcards config - annotated_features: - sample_id: - field: sample_id - comparative_analyses: - - repo: BrentLab/yeast_comparative_analysis - dataset: dto - via_field: binding_id - regulator_locus_tag: - field: regulator_locus_tag - dto_fdr: # Field from comparative dataset, optional renaming - field: dto_fdr - ``` - - """ - - sample_id: PropertyMapping | None = Field( - None, description="Mapping for sample identifier field" - ) - comparative_analyses: list[ComparativeAnalysis] = Field( - default_factory=list, - description="Comparative datasets referencing this dataset", - ) - # Allow additional property mappings via extra fields - model_config = ConfigDict(extra="allow") - - @model_validator(mode="before") - @classmethod - def parse_property_mappings(cls, data: Any) -> Any: - """Parse extra fields as PropertyMapping objects.""" - if not isinstance(data, dict): - return data - - # Process all fields except sample_id and comparative_analyses - result = {} - for key, value in data.items(): - if key in ("sample_id", "comparative_analyses"): - # These are typed fields, let Pydantic handle them - result[key] = value - elif isinstance(value, dict): - # Assume it's a PropertyMapping - try: - result[key] = PropertyMapping.model_validate(value) - except Exception as e: - raise ValueError( - f"Invalid PropertyMapping for field '{key}': {e}" - ) from e - else: - # Already parsed or wrong type - result[key] = value - - return result - - -class RepositoryConfig(BaseModel): - """ - Configuration for a single repository. Eg BrentLab/harbison_2004. - - Attributes: - properties: Repo-wide property mappings that apply to all datasets - dataset: Dataset-specific configurations including sample_id, - comparative_analyses, and property mappings - - Example: - ```python - config = RepositoryConfig( - properties={ - "temperature_celsius": PropertyMapping(path="temperature_celsius") - }, - dataset={ - "dataset_name": DatasetVirtualDBConfig( - sample_id=PropertyMapping(field="sample_id"), - comparative_analyses=[ - ComparativeAnalysis( - repo="BrentLab/yeast_comparative_analysis", - dataset="dto", - via_field="binding_id" - ) - ], - # Additional property mappings via extra fields - **{"carbon_source": PropertyMapping( - field="condition", - path="media.carbon_source" - )} - ) - } - ) - ``` - - """ - - properties: dict[str, PropertyMapping] = Field( - default_factory=dict, description="Repo-wide property mappings" - ) - dataset: dict[str, DatasetVirtualDBConfig] | None = Field( - None, description="Dataset-specific configurations" - ) - - @model_validator(mode="before") - @classmethod - def parse_structure(cls, data: Any) -> Any: - """Parse raw dict structure into typed objects.""" - if not isinstance(data, dict): - return data - - # Extract and parse dataset section - dataset_section = data.get("dataset") - parsed_datasets: dict[str, DatasetVirtualDBConfig] = {} - - if dataset_section: - if not isinstance(dataset_section, dict): - raise ValueError("'dataset' key must contain a dict") - for dataset_name, config_dict in dataset_section.items(): - if not isinstance(config_dict, dict): - raise ValueError(f"Dataset '{dataset_name}' must contain a dict") - - # Parse DatasetVirtualDBConfig - # The config_dict may contain: - # - sample_id (PropertyMapping) - # - comparative_analyses (list[ComparativeAnalysis]) - # - Other fields as PropertyMappings (via extra="allow") - try: - parsed_datasets[dataset_name] = ( - DatasetVirtualDBConfig.model_validate(config_dict) - ) - except Exception as e: - raise ValueError( - f"Invalid configuration for dataset '{dataset_name}': {e}" - ) from e - - # Parse repo-wide properties (all keys except 'dataset') - parsed_properties = {} - for key, value in data.items(): - if key == "dataset": - continue - - try: - parsed_properties[key] = PropertyMapping.model_validate(value) - except Exception as e: - raise ValueError(f"Invalid repo-wide property '{key}': {e}") from e - - return {"properties": parsed_properties, "dataset": parsed_datasets} - - -class MetadataConfig(BaseModel): - """ - Configuration for building standardized metadata tables. - - Specifies optional alias mappings for normalizing factor levels across - heterogeneous datasets, plus property path mappings for each repository. - - Attributes: - factor_aliases: Optional mappings of standardized names to actual values. - Example: {"carbon_source": - {"glucose": ["D-glucose", "dextrose"]}} - missing_value_labels: Labels for missing values by property name - description: Human-readable descriptions for each property - repositories: Dict mapping repository IDs to their configurations - - Example: - ```yaml - repositories: - BrentLab/harbison_2004: - dataset: - harbison_2004: - carbon_source: - field: condition - path: media.carbon_source - - BrentLab/kemmeren_2014: - temperature: - path: temperature_celsius - dataset: - kemmeren_2014: - carbon_source: - path: media.carbon_source - - factor_aliases: - carbon_source: - glucose: ["D-glucose", "dextrose"] - galactose: ["D-galactose", "Galactose"] - - missing_value_labels: - carbon_source: "unspecified" - - description: - carbon_source: "Carbon source in growth media" - ``` - - """ - - factor_aliases: dict[str, dict[str, list[Any]]] = Field( - default_factory=dict, - description="Optional alias mappings for normalizing factor levels", - ) - missing_value_labels: dict[str, str] = Field( - default_factory=dict, - description="Labels for missing values by property name", - ) - description: dict[str, str] = Field( - default_factory=dict, - description="Human-readable descriptions for each property", - ) - repositories: dict[str, RepositoryConfig] = Field( - ..., description="Repository configurations keyed by repo ID" - ) - - @field_validator("missing_value_labels", mode="before") - @classmethod - def validate_missing_value_labels(cls, v: Any) -> dict[str, str]: - """Validate missing value labels structure, filtering out None values.""" - if not v: - return {} - if not isinstance(v, dict): - raise ValueError("missing_value_labels must be a dict") - # Filter out None values that may come from empty YAML values - return {k: val for k, val in v.items() if val is not None} - - @field_validator("description", mode="before") - @classmethod - def validate_description(cls, v: Any) -> dict[str, str]: - """Validate description structure, filtering out None values.""" - if not v: - return {} - if not isinstance(v, dict): - raise ValueError("description must be a dict") - # Filter out None values that may come from empty YAML values - return {k: val for k, val in v.items() if val is not None} - - @field_validator("factor_aliases") - @classmethod - def validate_factor_aliases( - cls, v: dict[str, dict[str, list[Any]]] - ) -> dict[str, dict[str, list[Any]]]: - """Validate factor alias structure.""" - # Empty is OK - aliases are optional - if not v: - return v - - for prop_name, aliases in v.items(): - if not isinstance(aliases, dict): - raise ValueError( - f"Property '{prop_name}' aliases must be a dict, " - f"got {type(aliases).__name__}" - ) - - # Validate each alias mapping - for alias_name, actual_values in aliases.items(): - if not isinstance(actual_values, list): - raise ValueError( - f"Alias '{alias_name}' for '{prop_name}' must map " - f"to a list of values" - ) - if not actual_values: - raise ValueError( - f"Alias '{alias_name}' for '{prop_name}' cannot " - f"have empty value list" - ) - for val in actual_values: - if not isinstance(val, (str, int, float, bool)): - raise ValueError( - f"Alias '{alias_name}' for '{prop_name}' contains " - f"invalid value type: {type(val).__name__}" - ) - - return v - - @model_validator(mode="before") - @classmethod - def parse_repositories(cls, data: Any) -> Any: - """Parse repository configurations from 'repositories' key.""" - if not isinstance(data, dict): - return data - - # Extract repositories from 'repositories' key - repositories_data = data.get("repositories", {}) - - if not repositories_data: - raise ValueError( - "Configuration must have a 'repositories' key " - "with at least one repository" - ) - - if not isinstance(repositories_data, dict): - raise ValueError("'repositories' key must contain a dict") - - repositories = {} - for repo_id, repo_config in repositories_data.items(): - try: - repositories[repo_id] = RepositoryConfig.model_validate(repo_config) - except Exception as e: - raise ValueError( - f"Invalid configuration for repository '{repo_id}': {e}" - ) from e - - return { - "factor_aliases": data.get("factor_aliases", {}), - "missing_value_labels": data.get("missing_value_labels", {}), - "description": data.get("description", {}), - "repositories": repositories, - } - - @classmethod - def from_yaml(cls, path: Path | str) -> "MetadataConfig": - """ - Load and validate configuration from YAML file. - - :param path: Path to YAML configuration file - :return: Validated MetadataConfig instance - :raises FileNotFoundError: If file doesn't exist - :raises ValueError: If configuration is invalid - - """ - path = Path(path) - - if not path.exists(): - raise FileNotFoundError(f"Configuration file not found: {path}") - - with open(path) as f: - data = yaml.safe_load(f) - - if not isinstance(data, dict): - raise ValueError("Configuration must be a YAML dict") - - return cls.model_validate(data) - - def get_repository_config(self, repo_id: str) -> RepositoryConfig | None: - """ - Get configuration for a specific repository. - - :param repo_id: Repository ID (e.g., "BrentLab/harbison_2004") - :return: RepositoryConfig instance or None if not found - - """ - return self.repositories.get(repo_id) - - def get_property_mappings( - self, repo_id: str, config_name: str - ) -> dict[str, PropertyMapping]: - """ - Get merged property mappings for a repo/dataset combination. - - Merges repo-wide and dataset-specific mappings, with dataset-specific taking - precedence. - - :param repo_id: Repository ID - :param config_name: Dataset/config name - :return: Dict mapping property names to PropertyMapping objects - - """ - repo_config = self.get_repository_config(repo_id) - if not repo_config: - return {} - - # Start with repo-wide properties - mappings: dict[str, PropertyMapping] = dict(repo_config.properties) - - # Override with dataset-specific properties - if repo_config.dataset and config_name in repo_config.dataset: - dataset_config = repo_config.dataset[config_name] - # DatasetVirtualDBConfig stores property mappings in model_extra - if hasattr(dataset_config, "model_extra") and dataset_config.model_extra: - mappings.update(dataset_config.model_extra) - - return mappings diff --git a/tfbpapi/tests/test_models.py b/tfbpapi/tests/test_models.py index 4506131..d78b538 100644 --- a/tfbpapi/tests/test_models.py +++ b/tfbpapi/tests/test_models.py @@ -16,6 +16,7 @@ DatasetType, ExtractedMetadata, FeatureInfo, + MetadataConfig, MetadataRelationship, PartitioningInfo, ) @@ -575,3 +576,95 @@ def test_metadata_relationship_creation(self): assert relationship.data_config == "binding_data" assert relationship.metadata_config == "experiment_metadata" assert relationship.relationship_type == "explicit" + + +# ------------------------------------------------------------------ +# Minimal valid YAML snippets reused across MetadataConfig tests +# ------------------------------------------------------------------ + +_MINIMAL_CONFIG = { + "repositories": { + "BrentLab/harbison": { + "dataset": { + "harbison_2004": { + "sample_id": {"field": "sample_id"}, + } + } + } + } +} + + +class TestMetadataConfig: + """Tests for MetadataConfig Pydantic model validation.""" + + def test_valid_minimal_config(self): + """Minimal config with one repo and one dataset parses successfully.""" + config = MetadataConfig.model_validate(_MINIMAL_CONFIG) + assert "BrentLab/harbison" in config.repositories + + def test_missing_repositories_key_raises(self): + """Config missing 'repositories' raises ValueError.""" + with pytest.raises((ValidationError, ValueError)): + MetadataConfig.model_validate({}) + + def test_empty_repositories_raises(self): + """Config with empty 'repositories' dict raises ValueError.""" + with pytest.raises((ValidationError, ValueError)): + MetadataConfig.model_validate({"repositories": {}}) + + def test_repository_with_no_dataset_raises(self): + """Repository with no 'dataset' key raises ValueError.""" + with pytest.raises((ValidationError, ValueError)): + MetadataConfig.model_validate({"repositories": {"BrentLab/harbison": {}}}) + + def test_optional_sections_absent_succeeds(self): + """Parsing succeeds when optional sections are absent.""" + config = MetadataConfig.model_validate(_MINIMAL_CONFIG) + assert config.factor_aliases == {} + assert config.missing_value_labels == {} + + def test_optional_sections_present(self): + """Optional sections are parsed correctly when present.""" + data = { + "repositories": { + "BrentLab/harbison": { + "dataset": { + "harbison_2004": { + "sample_id": {"field": "sample_id"}, + } + } + } + }, + "factor_aliases": {"carbon_source": {"glucose": ["glu", "dextrose"]}}, + "missing_value_labels": {"carbon_source": "unspecified"}, + } + config = MetadataConfig.model_validate(data) + assert "carbon_source" in config.factor_aliases + assert config.missing_value_labels != {} + + def test_duplicate_db_name_raises(self): + """Duplicate db_name across datasets raises ValueError.""" + with pytest.raises((ValidationError, ValueError)): + MetadataConfig.model_validate( + { + "repositories": { + "BrentLab/harbison": { + "dataset": { + "harbison_2004": { + "db_name": "shared_name", + "sample_id": {"field": "sample_id"}, + } + } + }, + "BrentLab/kemmeren": { + "dataset": { + "kemmeren_2014": { + "db_name": "shared_name", + "sample_id": {"field": "sample_id"}, + } + } + }, + } + } + ) diff --git a/tfbpapi/tests/test_virtual_db.py b/tfbpapi/tests/test_virtual_db.py index fa02695..6129880 100644 --- a/tfbpapi/tests/test_virtual_db.py +++ b/tfbpapi/tests/test_virtual_db.py @@ -15,7 +15,7 @@ import yaml # type: ignore from tfbpapi.datacard import DatasetSchema -from tfbpapi.models import MetadataConfig +from tfbpapi.models import DatasetType, MetadataConfig from tfbpapi.virtual_db import VirtualDB # ------------------------------------------------------------------ @@ -389,6 +389,7 @@ def _make_mock_datacard(repo_id): else: config_mock = MagicMock() config_mock.metadata_fields = None + config_mock.dataset_type = DatasetType.COMPARATIVE card.get_config.return_value = config_mock card.get_field_definitions.return_value = {} card.get_experimental_conditions.return_value = {} @@ -406,8 +407,6 @@ def vdb(config_path, parquet_dir, monkeypatch): for local testing.""" import tfbpapi.virtual_db as vdb_module - v = VirtualDB(config_path) - def _fake_resolve(self, repo_id, config_name): return parquet_dir.get((repo_id, config_name), []) @@ -417,7 +416,7 @@ def _fake_resolve(self, repo_id, config_name): "_cached_datacard", lambda repo_id, token=None: _make_mock_datacard(repo_id), ) - return v + return VirtualDB(config_path) # ------------------------------------------------------------------ @@ -430,12 +429,20 @@ class TestVirtualDBConfig: def test_init_loads_config(self, config_path, monkeypatch): """Test that config loads without error.""" + monkeypatch.setattr(VirtualDB, "_load_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_validate_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_update_cache", lambda self: None) + monkeypatch.setattr(VirtualDB, "_register_all_views", lambda self: None) v = VirtualDB(config_path) assert v.config is not None assert v.token is None def test_init_with_token(self, config_path, monkeypatch): """Test token is stored.""" + monkeypatch.setattr(VirtualDB, "_load_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_validate_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_update_cache", lambda self: None) + monkeypatch.setattr(VirtualDB, "_register_all_views", lambda self: None) v = VirtualDB(config_path, token="tok123") assert v.token == "tok123" @@ -444,22 +451,18 @@ def test_init_missing_file(self): with pytest.raises(FileNotFoundError): VirtualDB("/nonexistent/path.yaml") - def test_repr_before_views(self, config_path): - """Test repr before views are registered.""" - v = VirtualDB(config_path) - r = repr(v) - assert "VirtualDB" in r - assert "views not yet registered" in r - - def test_repr_after_views(self, vdb): - """Test repr after views are registered.""" - vdb.tables() # triggers view registration + def test_repr(self, vdb): + """Test repr shows repo, dataset, and view counts.""" r = repr(vdb) assert "VirtualDB" in r assert "views)" in r - def test_db_name_map(self, config_path): + def test_db_name_map(self, config_path, monkeypatch): """Test that _db_name_map resolves db_name correctly.""" + monkeypatch.setattr(VirtualDB, "_load_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_validate_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_update_cache", lambda self: None) + monkeypatch.setattr(VirtualDB, "_register_all_views", lambda self: None) v = VirtualDB(config_path) assert "harbison" in v._db_name_map assert "kemmeren" in v._db_name_map @@ -599,13 +602,16 @@ def test_yaml_round_trip(self): tags_b = config.get_tags("BrentLab/repo_b", "dataset_b") assert tags_b == {"type": "perturbation"} - def _make_vdb(self, yaml_str: str, tmp_path) -> VirtualDB: - + def _make_vdb(self, yaml_str: str, tmp_path, monkeypatch) -> VirtualDB: + monkeypatch.setattr(VirtualDB, "_load_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_validate_datacards", lambda self: None) + monkeypatch.setattr(VirtualDB, "_update_cache", lambda self: None) + monkeypatch.setattr(VirtualDB, "_register_all_views", lambda self: None) p = tmp_path / "config.yaml" p.write_text(yaml_str) return VirtualDB(str(p)) - def test_vdb_get_tags_returns_merged(self, tmp_path): + def test_vdb_get_tags_returns_merged(self, tmp_path, monkeypatch): """VirtualDB.get_tags() returns merged repo+dataset tags by db_name.""" vdb = self._make_vdb( """ @@ -623,11 +629,12 @@ def test_vdb_get_tags_returns_merged(self, tmp_path): assay: chip-chip """, tmp_path, + monkeypatch, ) tags = vdb.get_tags("harbison") assert tags == {"assay": "chip-chip", "organism": "yeast"} - def test_vdb_get_tags_unknown_name_returns_empty(self, tmp_path): + def test_vdb_get_tags_unknown_name_returns_empty(self, tmp_path, monkeypatch): """VirtualDB.get_tags() returns empty dict for unknown db_name.""" vdb = self._make_vdb( """ @@ -640,11 +647,12 @@ def test_vdb_get_tags_unknown_name_returns_empty(self, tmp_path): field: sample_id """, tmp_path, + monkeypatch, ) assert vdb.get_tags("nonexistent") == {} - def test_vdb_get_tags_no_views_needed(self, tmp_path): - """VirtualDB.get_tags() works before any views are registered.""" + def test_vdb_get_tags_no_views_needed(self, tmp_path, monkeypatch): + """VirtualDB.get_tags() returns correct tags from config.""" vdb = self._make_vdb( """ repositories: @@ -658,15 +666,13 @@ def test_vdb_get_tags_no_views_needed(self, tmp_path): field: sample_id """, tmp_path, + monkeypatch, ) - assert not vdb._views_registered tags = vdb.get_tags("harbison") assert tags == {"assay": "binding"} - assert not vdb._views_registered - def test_vdb_get_datasets(self, tmp_path): - """VirtualDB.get_datasets() returns sorted db_names without registering - views.""" + def test_vdb_get_datasets(self, tmp_path, monkeypatch): + """VirtualDB.get_datasets() returns sorted db_names from config.""" vdb = self._make_vdb( """ repositories: @@ -684,10 +690,9 @@ def test_vdb_get_datasets(self, tmp_path): field: sample_id """, tmp_path, + monkeypatch, ) - assert not vdb._views_registered assert vdb.get_datasets() == ["harbison", "kemmeren"] - assert not vdb._views_registered # ------------------------------------------------------------------ @@ -696,7 +701,7 @@ def test_vdb_get_datasets(self, tmp_path): class TestViewRegistration: - """Tests for lazy view creation.""" + """Tests for view creation.""" def test_raw_views_created(self, vdb): """Test that raw per-dataset views exist.""" @@ -1031,6 +1036,8 @@ class TestEdgeCases: def test_no_parquet_files(self, tmp_path, monkeypatch): """Test graceful handling when no parquet files are found.""" + import tfbpapi.virtual_db as vdb_module + config = { "repositories": { "BrentLab/empty": { @@ -1046,22 +1053,60 @@ def test_no_parquet_files(self, tmp_path, monkeypatch): with open(p, "w") as f: yaml.dump(config, f) - v = VirtualDB(p) - def _fake_resolve(self, repo_id, config_name): return [] monkeypatch.setattr(VirtualDB, "_resolve_parquet_files", _fake_resolve) + monkeypatch.setattr( + vdb_module, + "_cached_datacard", + lambda repo_id, token=None: _make_mock_datacard(repo_id), + ) # Should not raise; just have no views + v = VirtualDB(p) views = v.tables() assert "empty_data" not in views - def test_lazy_init(self, config_path): - """Test that DuckDB connection is not created until needed.""" - v = VirtualDB(config_path) - assert v._conn is None - assert not v._views_registered + def test_links_with_non_comparative_dataset_type_raises( + self, tmp_path, monkeypatch + ): + """Dataset with 'links' but datacard dataset_type != comparative raises + ValueError.""" + import tfbpapi.virtual_db as vdb_module + + config = { + "repositories": { + "BrentLab/harbison": { + "dataset": { + "harbison_2004": { + "sample_id": {"field": "sample_id"}, + "links": { + "sample_id": [["BrentLab/primary", "primary_data"]] + }, + } + } + } + } + } + p = tmp_path / "config.yaml" + with open(p, "w") as f: + yaml.dump(config, f) + + non_comparative_card = _make_mock_datacard("BrentLab/harbison") + cfg_mock = MagicMock() + cfg_mock.dataset_type = DatasetType.ANNOTATED_FEATURES + non_comparative_card.get_config.return_value = cfg_mock + + monkeypatch.setattr(VirtualDB, "_resolve_parquet_files", lambda *a: []) + monkeypatch.setattr( + vdb_module, + "_cached_datacard", + lambda repo_id, token=None: non_comparative_card, + ) + + with pytest.raises(ValueError, match="comparative"): + VirtualDB(p) # ------------------------------------------------------------------ @@ -1128,8 +1173,6 @@ def test_non_default_sample_id(self, tmp_path, monkeypatch): is_partitioned=False, ) - v = VirtualDB(config_path) - monkeypatch.setattr( VirtualDB, "_resolve_parquet_files", @@ -1141,16 +1184,101 @@ def test_non_default_sample_id(self, tmp_path, monkeypatch): lambda repo_id, token=None: mock_card, ) - # Meta view should have experiment_id + regulator + v = VirtualDB(config_path) + + # Meta view should rename experiment_id -> sample_id meta_df = v.query("SELECT * FROM custom_meta") - assert "experiment_id" in meta_df.columns + assert "sample_id" in meta_df.columns + assert "experiment_id" not in meta_df.columns + assert list(meta_df["sample_id"]) == [100, 200] or set( + meta_df["sample_id"] + ) == {100, 200} assert len(meta_df) == 2 # 2 distinct samples - # Enriched raw view should JOIN on experiment_id + # Enriched raw view should also expose sample_id raw_df = v.query("SELECT * FROM custom") - assert "experiment_id" in raw_df.columns + assert "sample_id" in raw_df.columns + assert "experiment_id" not in raw_df.columns assert len(raw_df) == 4 # all rows + def test_non_default_sample_id_with_collision(self, tmp_path, monkeypatch): + """When parquet has both gm_id (sample) and sample_id (other col), gm_id is + renamed to sample_id and sample_id is preserved as sample_id_orig.""" + import tfbpapi.virtual_db as vdb_module + + config = { + "repositories": { + "TestOrg/collision": { + "dataset": { + "collision_data": { + "db_name": "collision", + "sample_id": {"field": "gm_id"}, + "regulator": {"field": "regulator"}, + } + } + } + } + } + config_path = tmp_path / "config.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + + # Parquet has gm_id (the real sample id) AND a literal sample_id col + df = pd.DataFrame( + { + "gm_id": [1, 1, 2, 2], + "sample_id": [101, 101, 102, 102], # some other field + "regulator": ["TF1", "TF1", "TF2", "TF2"], + "target": ["G1", "G2", "G1", "G2"], + "score": [1.0, 2.0, 3.0, 4.0], + } + ) + parquet_path = tmp_path / "collision.parquet" + files = { + ("TestOrg/collision", "collision_data"): [_write_parquet(parquet_path, df)], + } + + mock_card = MagicMock() + mock_card.get_metadata_fields.return_value = ["regulator"] + mock_card.get_field_definitions.return_value = {} + mock_card.get_experimental_conditions.return_value = {} + mock_card.get_dataset_schema.return_value = DatasetSchema( + data_columns={"gm_id", "sample_id", "target", "score"}, + metadata_columns={"regulator"}, + join_columns=set(), + metadata_source="embedded", + external_metadata_config=None, + is_partitioned=False, + ) + + monkeypatch.setattr( + VirtualDB, + "_resolve_parquet_files", + lambda self, repo_id, cn: files.get((repo_id, cn), []), + ) + monkeypatch.setattr( + vdb_module, + "_cached_datacard", + lambda repo_id, token=None: mock_card, + ) + + v = VirtualDB(config_path) + + # Meta view: gm_id -> sample_id, original sample_id -> sample_id_orig + meta_df = v.query("SELECT * FROM collision_meta") + assert "sample_id" in meta_df.columns + assert "sample_id_orig" in meta_df.columns + assert "gm_id" not in meta_df.columns + assert set(meta_df["sample_id"]) == {1, 2} + assert set(meta_df["sample_id_orig"]) == {101, 102} + + # Raw view same behavior + raw_df = v.query("SELECT * FROM collision") + assert "sample_id" in raw_df.columns + assert "sample_id_orig" in raw_df.columns + assert "gm_id" not in raw_df.columns + assert len(raw_df) == 4 + def test_get_sample_id_field_dataset_level(self): """Dataset-level sample_id takes precedence.""" config = MetadataConfig.model_validate( @@ -1307,7 +1435,6 @@ def test_external_metadata_join(self, tmp_path, monkeypatch): is_partitioned=False, ) - v = VirtualDB(config_file) monkeypatch.setattr( VirtualDB, "_resolve_parquet_files", @@ -1319,7 +1446,7 @@ def test_external_metadata_join(self, tmp_path, monkeypatch): lambda repo_id, token=None: card, ) - # Trigger view registration + v = VirtualDB(config_file) tables = v.tables() assert "chip" in tables assert "chip_meta" in tables diff --git a/tfbpapi/virtual_db.py b/tfbpapi/virtual_db.py index 86c9ea3..eaa77d3 100644 --- a/tfbpapi/virtual_db.py +++ b/tfbpapi/virtual_db.py @@ -6,8 +6,8 @@ https://brentlab.github.io/tfbpapi/huggingface_datacard/. Next, a developer can create a virtualDB configuration file that describes which huggingface repos and datasets to use, a set of common fields, datasets that contain comparative analytics, and more. -VirtualDB, this code, then uses DuckDB to construct views that are lazily created -over Parquet files cached locally. For primary datasets, VirtualDB creates metadata +VirtualDB, this code, then uses DuckDB to construct views over Parquet files cached +locally on initialization. For primary datasets, VirtualDB creates metadata views (one row per sample with derived columns) and full data views (measurement-level data joined to metadata). For comparative analysis datasets, VirtualDB creates expanded views that parse composite ID fields into ``_source`` (aliased to the configured @@ -52,7 +52,7 @@ from duckdb import BinderException from tfbpapi.datacard import DataCard, DatasetSchema -from tfbpapi.models import MetadataConfig +from tfbpapi.models import DatasetType, MetadataConfig logger = logging.getLogger(__name__) @@ -169,45 +169,30 @@ def __init__( config_path: Path | str, token: str | None = None, duckdb_connection: duckdb.DuckDBPyConnection | None = None, - views_registered: bool = False, - lazy: bool = True, ): """ Initialize VirtualDB with configuration. + Creates the DuckDB connection and registers all views immediately. + :param config_path: Path to YAML configuration file :param token: Optional HuggingFace token for private datasets :param duckdb_connection: Optional DuckDB connection. If provided, views will be registered on this connection instead of creating a new in-memory database. - Note that this provides a method of using a persistent database file. If - this isn't provided, then the duckDB connection is in-memory. - :param views_registered: If True, skip view registration (assumes views are - already registered on the provided duckdb_connection). This is useful when - reusing a connection across multiple VirtualDB instances with the same - config. - :param lazy: If True, delay DuckDB connection and view registration until first - query. Set to False to register views immediately on initialization. This is - intended to be used when creating a persistent duckDB connection. If the - views are registered immediately on initialization, then for any other - instances of VirtualDB that are initialized with the same duckDB connection - and config, the views will already be registered and available for querying. + This provides a method of using a persistent database file. If not provided, + an in-memory DuckDB connection is created. :raises FileNotFoundError: If config file does not exist - :raises ValueError: If configuration is invalid or if views_registered=True is - set when lazy=False + :raises ValueError: If configuration is invalid """ - if not lazy and views_registered: - raise ValueError( - "Cannot set views_registered=True when lazy=False. " - "If lazy=False, views will be registered immediately on initialization." - ) self.config = MetadataConfig.from_yaml(config_path) self.token = token - # Instantiate without creating a connection, if no connection is provided. - # the connection is created when needed by calling self._ensure_sql_views() - self._conn: duckdb.DuckDBPyConnection | None = duckdb_connection - self._views_registered = views_registered + self._conn: duckdb.DuckDBPyConnection = ( + duckdb_connection + if duckdb_connection is not None + else duckdb.connect(":memory:") + ) # db_name -> (repo_id, config_name) self._db_name_map = self._build_db_name_map() @@ -215,17 +200,10 @@ def __init__( # Prepared queries: name -> sql self._prepared_queries: dict[str, str] = {} - # If not lazy, create the DuckDB connection and register views immediately. - if not lazy: - self._ensure_sql_views() - - @property - def _db(self) -> duckdb.DuckDBPyConnection: - """Return the DuckDB connection, asserting it is initialized.""" - assert self._conn is not None, ( - "DuckDB connection not initialized. " "Call _ensure_sql_views() first." - ) - return self._conn + self._load_datacards() + self._validate_datacards() + self._update_cache() + self._register_all_views() # ------------------------------------------------------------------ # Public API @@ -259,16 +237,14 @@ def query(self, sql: str, **params: Any) -> pd.DataFrame: df = vdb.query("top", n=10) """ - self._ensure_sql_views() - # param `sql` may be a prepared query name, a raw sql statement, or # a parameterized sql statement that is not prepared. If it exists as a key # in the _prepared_queries dict, we use the prepared sql. Otherwise, we # use the sql as passed to query(). resolved = self._prepared_queries.get(sql, sql) if params: - return self._db.execute(resolved, params).fetchdf() - return self._db.execute(resolved).fetchdf() + return self._conn.execute(resolved, params).fetchdf() + return self._conn.execute(resolved).fetchdf() def prepare(self, name: str, sql: str, overwrite: bool = False) -> None: """ @@ -294,7 +270,7 @@ def prepare(self, name: str, sql: str, overwrite: bool = False) -> None: df = vdb.query("glucose_regs", cs="glucose", min_n=2) """ - self._ensure_sql_views() + if name in self._list_views() and not overwrite: error_msg = ( f"Prepared-query name '{name}' collides with " @@ -312,7 +288,7 @@ def tables(self) -> list[str]: :return: Sorted list of view names """ - self._ensure_sql_views() + return sorted(self._list_views()) def describe(self, table: str | None = None) -> pd.DataFrame: @@ -324,15 +300,15 @@ def describe(self, table: str | None = None) -> pd.DataFrame: ``column_type`` """ - self._ensure_sql_views() + if table is not None: - df = self._db.execute(f"DESCRIBE {table}").fetchdf() + df = self._conn.execute(f"DESCRIBE {table}").fetchdf() df.insert(0, "table", table) return df frames = [] for view in sorted(self._list_views()): - df = self._db.execute(f"DESCRIBE {view}").fetchdf() + df = self._conn.execute(f"DESCRIBE {view}").fetchdf() df.insert(0, "table", view) frames.append(df) if not frames: @@ -347,9 +323,9 @@ def get_fields(self, table: str | None = None) -> list[str]: :return: Sorted list of column names """ - self._ensure_sql_views() + if table is not None: - cols = self._db.execute( + cols = self._conn.execute( f"SELECT column_name FROM information_schema.columns " f"WHERE table_name = '{table}'" ).fetchdf() @@ -357,7 +333,7 @@ def get_fields(self, table: str | None = None) -> list[str]: all_cols: set[str] = set() for view in self._list_views(): - cols = self._db.execute( + cols = self._conn.execute( f"SELECT column_name FROM information_schema.columns " f"WHERE table_name = '{view}'" ).fetchdf() @@ -374,14 +350,14 @@ def get_common_fields(self) -> list[str]: :return: Sorted list of common column names """ - self._ensure_sql_views() + meta_views = self._get_primary_meta_view_names() if not meta_views: return [] sets = [] for view in meta_views: - cols = self._db.execute( + cols = self._conn.execute( f"SELECT column_name FROM information_schema.columns " f"WHERE table_name = '{view}'" ).fetchdf() @@ -430,42 +406,82 @@ def get_tags(self, db_name: str) -> dict[str, str]: return self.config.get_tags(repo_id, config_name) # ------------------------------------------------------------------ - # Lazy initialisation + # Initialisation phases # ------------------------------------------------------------------ - def _ensure_sql_views(self) -> None: - """Create DuckDB connection and register all views on first call.""" - if self._views_registered: - return - self._conn = duckdb.connect(":memory:") - self._register_all_views() - self._views_registered = True + def _load_datacards(self) -> None: + """ + Fetch (or load from cache) the DataCard for every distinct repo. - def _register_all_views(self) -> None: - """Orchestrate view registration in dependency order.""" - # 1. Raw per-dataset views (internal ___parquet - # plus public for primary datasets only) - for db_name, (repo_id, config_name) in self._db_name_map.items(): - comparative = self._is_comparative(repo_id, config_name) - self._register_raw_view( - db_name, - repo_id, - config_name, - parquet_only=comparative, - ) + Populates ``self._datacards`` keyed by ``repo_id``. Failures are + logged as warnings and the repo is omitted from the dict so that + subsequent phases can skip it gracefully. - # 1b. Resolve external metadata parquet views. - # When a data config's metadata lives in a separate HF config - # (applies_to), register its parquet as ___metadata_parquet. - # All information is derived from DataCard YAML -- no DuckDB - # introspection needed. + """ + self._datacards: dict[str, DataCard] = {} + seen_repos: set[str] = set() + for repo_id, _ in self._db_name_map.values(): + if repo_id in seen_repos: + continue + seen_repos.add(repo_id) + try: + self._datacards[repo_id] = _cached_datacard(repo_id, token=self.token) + except Exception as exc: + logger.warning( + "Could not load datacard for repo '%s': %s", + repo_id, + exc, + ) + + def _validate_datacards(self) -> None: + """ + Cross-check the VirtualDB config against the loaded datacards. + + Checks that every dataset with a ``links`` field in the VirtualDB + config has ``dataset_type: comparative`` in its HuggingFace datacard. + Also resolves ``self._dataset_schemas`` and + ``self._external_meta_configs`` (keyed by ``db_name``) for use by + ``_update_cache`` and ``_register_all_views``. + + :raises ValueError: If a dataset with ``links`` does not have + ``dataset_type: comparative`` in its datacard. + + """ self._dataset_schemas: dict[str, DatasetSchema] = {} - self._external_meta_views: dict[str, str] = {} + # db_name -> external metadata config_name (for applies_to datasets) + self._external_meta_configs: dict[str, str] = {} + for db_name, (repo_id, config_name) in self._db_name_map.items(): - if self._is_comparative(repo_id, config_name): + repo_cfg = self.config.repositories.get(repo_id) + ds_cfg = ( + repo_cfg.dataset.get(config_name) + if repo_cfg and repo_cfg.dataset + else None + ) + card = self._datacards.get(repo_id) + + # Validate comparative dataset_type agreement. + if ds_cfg and ds_cfg.links: + if card is not None: + dc_config = card.get_config(config_name) + if ( + dc_config is not None + and dc_config.dataset_type != DatasetType.COMPARATIVE + ): + raise ValueError( + f"Dataset '{config_name}' in repo '{repo_id}' has " + f"'links' in the VirtualDB config, indicating a " + f"comparative dataset, but the HuggingFace datacard " + f"declares dataset_type='{dc_config.dataset_type}'. " + f"Update the datacard to use dataset_type: comparative." + ) + continue # comparative datasets need no schema resolution + + # Resolve dataset schema and external metadata config for + # primary datasets. + if card is None: continue try: - card = _cached_datacard(repo_id, token=self.token) schema = card.get_dataset_schema(config_name) except Exception as exc: logger.warning( @@ -478,26 +494,69 @@ def _register_all_views(self) -> None: if schema is not None: self._dataset_schemas[db_name] = schema if ( - schema is None - or schema.metadata_source != "external" - or not schema.external_metadata_config + schema is not None + and schema.metadata_source == "external" + and schema.external_metadata_config ): - continue - meta_view = f"__{db_name}_metadata_parquet" - files = self._resolve_parquet_files( - repo_id, schema.external_metadata_config + self._external_meta_configs[db_name] = schema.external_metadata_config + + def _update_cache(self) -> None: + """ + Download (or locate cached) Parquet files for all dataset configs. + + Populates ``self._parquet_files`` keyed by ``db_name``. For datasets + with external metadata (identified during ``_validate_datacards``), + also downloads those files and stores them under the key + ``"___meta"`` so ``_register_all_views`` can read them + without further network calls. + + """ + self._parquet_files: dict[str, list[str]] = {} + for db_name, (repo_id, config_name) in self._db_name_map.items(): + files = self._resolve_parquet_files(repo_id, config_name) + self._parquet_files[db_name] = files + + for db_name, ext_config_name in self._external_meta_configs.items(): + repo_id, _ = self._db_name_map[db_name] + files = self._resolve_parquet_files(repo_id, ext_config_name) + self._parquet_files[f"__{db_name}_meta"] = files + + def _register_all_views(self) -> None: + """ + Register all DuckDB views in dependency order. + + Expects ``self._parquet_files``, ``self._dataset_schemas``, and + ``self._external_meta_configs`` to have been populated by the earlier + init phases. No network or disk access occurs here. + + """ + # 1. Raw per-dataset views (internal ___parquet + # plus public for primary datasets only) + for db_name, (repo_id, config_name) in self._db_name_map.items(): + comparative = self._is_comparative(repo_id, config_name) + self._register_raw_view( + db_name, + parquet_only=comparative, ) + + # 2. External metadata parquet views. + # When a data config's metadata lives in a separate HF config + # (applies_to), register its parquet as ___metadata_parquet. + self._external_meta_views: dict[str, str] = {} + for db_name, ext_config_name in self._external_meta_configs.items(): + meta_view = f"__{db_name}_metadata_parquet" + files = self._parquet_files.get(f"__{db_name}_meta", []) if not files: logger.warning( "No parquet files for external metadata config " - "'%s' in repo '%s'", - schema.external_metadata_config, - repo_id, + "'%s' (db_name '%s') -- skipping external metadata view", + ext_config_name, + db_name, ) continue files_sql = ", ".join(f"'{f}'" for f in files) try: - self._db.execute( + self._conn.execute( f"CREATE OR REPLACE VIEW {meta_view} AS " f"SELECT * FROM read_parquet([{files_sql}])" ) @@ -510,24 +569,18 @@ def _register_all_views(self) -> None: continue self._external_meta_views[db_name] = meta_view - # 2. Metadata views for primary datasets (_meta) - # This is based on the metadata defined in the datacard, - # and includes any additional derived columns based on the - # virtualDB config passed in at initialization. Note that - # this is joined onto the raw view in the next step. + # 3. Metadata views for primary datasets (_meta) for db_name, (repo_id, config_name) in self._db_name_map.items(): if not self._is_comparative(repo_id, config_name): self._register_meta_view(db_name, repo_id, config_name) - # 3. Replace primary raw views with join to _meta so + # 4. Replace primary raw views with join to _meta so # derived columns (e.g. carbon_source) are available for db_name, (repo_id, config_name) in self._db_name_map.items(): if not self._is_comparative(repo_id, config_name): self._enrich_raw_view(db_name) - # 4. Comparative expanded views (pre-parsed composite IDs) - # These build directly on ___parquet since - # comparative datasets have no _meta or enriched raw view. + # 5. Comparative expanded views (pre-parsed composite IDs) for db_name, (repo_id, config_name) in self._db_name_map.items(): repo_cfg = self.config.repositories.get(repo_id) if not repo_cfg or not repo_cfg.dataset: @@ -617,13 +670,11 @@ def _resolve_parquet_files(self, repo_id: str, config_name: str) -> list[str]: def _register_raw_view( self, db_name: str, - repo_id: str, - config_name: str, *, parquet_only: bool = False, ) -> None: """ - Register a raw DuckDB view over Parquet files. + Register a raw DuckDB view over pre-resolved Parquet files. Creates an internal ``___parquet`` view that reads directly from the Parquet files. For primary datasets, also @@ -633,33 +684,44 @@ def _register_raw_view( For comparative datasets, only the internal parquet view is created; the public view is the ``_expanded`` view instead. + Parquet files must have been resolved by ``_update_cache`` + before this method is called. + :param db_name: View name - :param repo_id: Repository ID - :param config_name: Configuration name :param parquet_only: If True, only create the internal ``___parquet`` view (no public ````). """ - files = self._resolve_parquet_files(repo_id, config_name) + files = self._parquet_files.get(db_name, []) if not files: logger.warning( - "No parquet files for %s/%s -- skipping view '%s'", - repo_id, - config_name, + "No parquet files for db_name '%s' -- skipping view", db_name, ) return files_sql = ", ".join(f"'{f}'" for f in files) parquet_sql = f"SELECT * FROM read_parquet([{files_sql}])" - self._db.execute( + self._conn.execute( f"CREATE OR REPLACE VIEW __{db_name}_parquet AS " f"{parquet_sql}" ) if not parquet_only: - self._db.execute( - f"CREATE OR REPLACE VIEW {db_name} AS " - f"SELECT * FROM __{db_name}_parquet" - ) + sample_col = self._get_sample_id_col(db_name) + if sample_col == "sample_id": + public_select = f"SELECT * FROM __{db_name}_parquet" + else: + raw_cols = self._get_view_columns(f"__{db_name}_parquet") + parts: list[str] = [] + for col in raw_cols: + if col == sample_col: + parts.append(f"{col} AS sample_id") + elif col == "sample_id": + parts.append(f"{col} AS sample_id_orig") + else: + parts.append(col) + cols_sql = ", ".join(parts) + public_select = f"SELECT {cols_sql} FROM __{db_name}_parquet" + self._conn.execute(f"CREATE OR REPLACE VIEW {db_name} AS {public_select}") def _register_meta_view(self, db_name: str, repo_id: str, config_name: str) -> None: """ @@ -693,12 +755,8 @@ def _register_meta_view(self, db_name: str, repo_id: str, config_name: str) -> N # Pull ext_meta_view early -- needed for both meta_cols and # FROM clause construction. - schema: DatasetSchema | None = getattr(self, "_dataset_schemas", {}).get( - db_name - ) - ext_meta_view: str | None = getattr(self, "_external_meta_views", {}).get( - db_name - ) + schema: DatasetSchema | None = self._dataset_schemas.get(db_name) + ext_meta_view: str | None = self._external_meta_views.get(db_name) is_external = ( ext_meta_view is not None @@ -757,16 +815,32 @@ def qualify(col: str) -> str: return f"m.{col}" return f"d.{col}" - # Build SELECT: sample_id + metadata cols (deduplicated) + # Build SELECT: sample_id + metadata cols (deduplicated). + # If the configured sample_id column differs from "sample_id", + # rename it so all views expose a consistent "sample_id" column. + # If the parquet also has a literal "sample_id" column, preserve + # it as "sample_id_orig" to avoid losing data. seen: set[str] = set() select_parts: list[str] = [] + rename_sample = sample_col != "sample_id" def add_col(col: str) -> None: if col not in seen: seen.add(col) - select_parts.append(qualify(col)) + if rename_sample and col == sample_col: + select_parts.append(f"{qualify(col)} AS sample_id") + elif rename_sample and col == "sample_id": + select_parts.append(f"{qualify(col)} AS sample_id_orig") + else: + select_parts.append(qualify(col)) add_col(sample_col) + # When renaming, check if the parquet source also has a literal + # "sample_id" column; if so, preserve it as "sample_id_orig". + if rename_sample: + source_cols = set(self._get_view_columns(parquet_view)) + if "sample_id" in source_cols: + add_col("sample_id") for col in meta_cols: add_col(col) @@ -798,7 +872,7 @@ def add_col(col: str) -> None: f"SELECT DISTINCT {cols_sql} FROM {from_clause}" ) try: - self._db.execute(sql) + self._conn.execute(sql) except BinderException as exc: raise BinderException( f"Failed to create meta view '{db_name}_meta'.\n" @@ -825,20 +899,62 @@ def _enrich_raw_view(self, db_name: str) -> None: if not self._view_exists(meta_name) or not self._view_exists(parquet_name): return - raw_cols = set(self._get_view_columns(parquet_name)) + raw_cols_list = self._get_view_columns(parquet_name) + raw_cols = set(raw_cols_list) meta_cols = set(self._get_view_columns(meta_name)) - extra_cols = meta_cols - raw_cols + + sample_col = self._get_sample_id_col(db_name) + rename_sample = sample_col != "sample_id" + + # Columns to pull from _meta that aren't already in raw parquet, + # accounting for the sample_id rename: when renaming, "sample_id" + # will appear in meta_cols (as the renamed column) but not in + # raw_cols (which has the original name), so we must exclude it + # from extra_cols since the rename in the raw SELECT already + # provides it. + if rename_sample: + # "sample_id" and "sample_id_orig" come from the raw SELECT + # rename, not from meta + extra_cols = meta_cols - raw_cols - {"sample_id", "sample_id_orig"} + else: + extra_cols = meta_cols - raw_cols if not extra_cols: + # No derived columns to add -- the view created in + # _register_raw_view (which already handles the rename) + # is sufficient. return - sample_col = self._get_sample_id_col(db_name) - extra_select = ", ".join(f"m.{c}" for c in sorted(extra_cols)) - self._db.execute( + if rename_sample: + # Build explicit SELECT to rename the sample column + raw_parts: list[str] = [] + for col in raw_cols_list: + if col == sample_col: + raw_parts.append(f"r.{col} AS sample_id") + elif col == "sample_id": + raw_parts.append(f"r.{col} AS sample_id_orig") + else: + raw_parts.append(f"r.{col}") + raw_select = ", ".join(raw_parts) + else: + raw_select = "r.*" + + if extra_cols: + extra_select = ", ".join(f"m.{c}" for c in sorted(extra_cols)) + full_select = f"{raw_select}, {extra_select}" + else: + full_select = raw_select + + if rename_sample: + join_clause = f"JOIN {meta_name} m ON r.{sample_col} = m.sample_id" + else: + join_clause = f"JOIN {meta_name} m USING ({sample_col})" + + self._conn.execute( f"CREATE OR REPLACE VIEW {db_name} AS " - f"SELECT r.*, {extra_select} " + f"SELECT {full_select} " f"FROM {parquet_name} r " - f"JOIN {meta_name} m USING ({sample_col})" + f"{join_clause}" ) def _get_view_columns(self, view: str) -> list[str]: @@ -850,7 +966,7 @@ def _get_view_columns(self, view: str) -> list[str]: which DuckDB may evaluate lazily. """ - df = self._db.execute(f"DESCRIBE {view}").fetchdf() + df = self._conn.execute(f"DESCRIBE {view}").fetchdf() return df["column_name"].tolist() def _get_sample_id_col(self, db_name: str) -> str: @@ -880,7 +996,9 @@ def _resolve_metadata_fields( """ try: - card = _cached_datacard(repo_id, token=self.token) + card = self._datacards.get(repo_id) or _cached_datacard( + repo_id, token=self.token + ) return card.get_metadata_fields(config_name) except Exception: logger.error( @@ -939,7 +1057,9 @@ def _resolve_property_columns( raw_cols: set[str] = set() try: - card = _cached_datacard(repo_id, token=self.token) + card = self._datacards.get(repo_id) or _cached_datacard( + repo_id, token=self.token + ) except Exception as exc: logger.warning( "Could not load DataCard for %s: %s", @@ -1218,7 +1338,7 @@ def _register_comparative_expanded_view( return cols_sql = ", ".join(extra_cols) - self._db.execute( + self._conn.execute( f"CREATE OR REPLACE VIEW {db_name}_expanded AS " f"SELECT *, {cols_sql} FROM {parquet_view}" ) @@ -1237,7 +1357,7 @@ def _is_comparative(self, repo_id: str, config_name: str) -> bool: def _list_views(self) -> list[str]: """Return list of public views (excludes internal __ prefixed).""" - df = self._db.execute( + df = self._conn.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = 'main' AND table_type = 'VIEW'" ).fetchdf() @@ -1245,7 +1365,7 @@ def _list_views(self) -> list[str]: def _view_exists(self, name: str) -> bool: """Check whether a view is registered (including internal).""" - df = self._db.execute( + df = self._conn.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = 'main' AND table_type = 'VIEW' " f"AND table_name = '{name}'" @@ -1285,14 +1405,9 @@ def __repr__(self) -> str: """String representation.""" n_repos = len(self.config.repositories) n_datasets = len(self._db_name_map) - if self._views_registered: - n_views = len(self._list_views()) - return ( - f"VirtualDB({n_repos} repos, " - f"{n_datasets} datasets, " - f"{n_views} views)" - ) + n_views = len(self._list_views()) return ( f"VirtualDB({n_repos} repos, " - f"{n_datasets} datasets, views not yet registered)" + f"{n_datasets} datasets, " + f"{n_views} views)" )