diff --git a/.gitignore b/.gitignore index 884098a30..e37e092cc 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,8 @@ datajoint.json # Test outputs *_test_summary.txt + +# Swap files +*.swp +*.swo +*~ diff --git a/docker-compose.yaml b/docker-compose.yaml index 2c48ffd10..23fd773c1 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -24,6 +24,19 @@ services: timeout: 30s retries: 5 interval: 15s + postgres: + image: postgres:${POSTGRES_VER:-15} + environment: + - POSTGRES_PASSWORD=${PG_PASS:-password} + - POSTGRES_USER=${PG_USER:-postgres} + - POSTGRES_DB=${PG_DB:-test} + ports: + - "5432:5432" + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U postgres" ] + timeout: 30s + retries: 5 + interval: 15s minio: image: minio/minio:${MINIO_VER:-RELEASE.2025-02-28T09-55-16Z} environment: @@ -52,6 +65,8 @@ services: depends_on: db: condition: service_healthy + postgres: + condition: service_healthy minio: condition: service_healthy environment: @@ -61,6 +76,10 @@ services: - DJ_TEST_HOST=db - DJ_TEST_USER=datajoint - DJ_TEST_PASSWORD=datajoint + - DJ_PG_HOST=postgres + - DJ_PG_USER=postgres + - DJ_PG_PASS=password + - DJ_PG_PORT=5432 - S3_ENDPOINT=minio:9000 - S3_ACCESS_KEY=datajoint - S3_SECRET_KEY=datajoint diff --git a/pyproject.toml b/pyproject.toml index 7cd06d786..fd33dfd53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ test = [ "pytest-cov", "requests", "graphviz", - "testcontainers[mysql,minio]>=4.0", + "testcontainers[mysql,minio,postgres]>=4.0", "polars>=0.20.0", "pyarrow>=14.0.0", ] @@ -98,6 +98,7 @@ test = [ s3 = ["s3fs>=2023.1.0"] gcs = ["gcsfs>=2023.1.0"] azure = ["adlfs>=2023.1.0"] +postgres = ["psycopg2-binary>=2.9.0"] polars = ["polars>=0.20.0"] arrow = ["pyarrow>=14.0.0"] test = [ @@ -105,7 +106,8 @@ test = [ "pytest-cov", "requests", "s3fs>=2023.1.0", - "testcontainers[mysql,minio]>=4.0", + "testcontainers[mysql,minio,postgres]>=4.0", + "psycopg2-binary>=2.9.0", "polars>=0.20.0", "pyarrow>=14.0.0", ] @@ -227,6 +229,9 @@ ignore-words-list = "rever,numer,astroid" markers = [ "requires_mysql: marks tests as requiring MySQL database (deselect with '-m \"not requires_mysql\"')", "requires_minio: marks tests as requiring MinIO object storage (deselect with '-m \"not requires_minio\"')", + "mysql: marks tests that run on MySQL backend (select with '-m mysql')", + "postgresql: marks tests that run on PostgreSQL backend (select with '-m postgresql')", + "backend_agnostic: marks tests that should pass on all backends (auto-marked for parameterized tests)", ] diff --git a/src/datajoint/adapters/__init__.py b/src/datajoint/adapters/__init__.py new file mode 100644 index 000000000..5115a982a --- /dev/null +++ b/src/datajoint/adapters/__init__.py @@ -0,0 +1,54 @@ +""" +Database adapter registry for DataJoint. + +This module provides the adapter factory function and exports all adapters. +""" + +from __future__ import annotations + +from .base import DatabaseAdapter +from .mysql import MySQLAdapter +from .postgres import PostgreSQLAdapter + +__all__ = ["DatabaseAdapter", "MySQLAdapter", "PostgreSQLAdapter", "get_adapter"] + +# Adapter registry mapping backend names to adapter classes +ADAPTERS: dict[str, type[DatabaseAdapter]] = { + "mysql": MySQLAdapter, + "postgresql": PostgreSQLAdapter, + "postgres": PostgreSQLAdapter, # Alias for postgresql +} + + +def get_adapter(backend: str) -> DatabaseAdapter: + """ + Get adapter instance for the specified database backend. 
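A quick sketch of the registry behaviour this factory wraps, assuming psycopg2 is installed so the PostgreSQL adapter can actually be constructed; note that 'postgres' is accepted as an alias for 'postgresql':

>>> from datajoint.adapters import ADAPTERS, get_adapter
>>> sorted(set(ADAPTERS))
['mysql', 'postgres', 'postgresql']
>>> type(get_adapter('postgres')).__name__
'PostgreSQLAdapter'
>>> get_adapter('sqlite')  # raises ValueError listing the supported backends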
+ + Parameters + ---------- + backend : str + Backend name: 'mysql', 'postgresql', or 'postgres'. + + Returns + ------- + DatabaseAdapter + Adapter instance for the specified backend. + + Raises + ------ + ValueError + If the backend is not supported. + + Examples + -------- + >>> from datajoint.adapters import get_adapter + >>> mysql_adapter = get_adapter('mysql') + >>> postgres_adapter = get_adapter('postgresql') + """ + backend_lower = backend.lower() + + if backend_lower not in ADAPTERS: + supported = sorted(set(ADAPTERS.keys())) + raise ValueError(f"Unknown database backend: {backend}. " f"Supported backends: {', '.join(supported)}") + + return ADAPTERS[backend_lower]() diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py new file mode 100644 index 000000000..35b32ed5f --- /dev/null +++ b/src/datajoint/adapters/base.py @@ -0,0 +1,1169 @@ +""" +Abstract base class for database backend adapters. + +This module defines the interface that all database adapters must implement +to support multiple database backends (MySQL, PostgreSQL, etc.) in DataJoint. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class DatabaseAdapter(ABC): + """ + Abstract base class for database backend adapters. + + Adapters provide database-specific implementations for SQL generation, + type mapping, error translation, and connection management. + """ + + # ========================================================================= + # Connection Management + # ========================================================================= + + @abstractmethod + def connect( + self, + host: str, + port: int, + user: str, + password: str, + **kwargs: Any, + ) -> Any: + """ + Establish database connection. + + Parameters + ---------- + host : str + Database server hostname. + port : int + Database server port. + user : str + Username for authentication. + password : str + Password for authentication. + **kwargs : Any + Additional backend-specific connection parameters. + + Returns + ------- + Any + Database connection object (backend-specific). + """ + ... + + @abstractmethod + def close(self, connection: Any) -> None: + """ + Close the database connection. + + Parameters + ---------- + connection : Any + Database connection object to close. + """ + ... + + @abstractmethod + def ping(self, connection: Any) -> bool: + """ + Check if connection is alive. + + Parameters + ---------- + connection : Any + Database connection object to check. + + Returns + ------- + bool + True if connection is alive, False otherwise. + """ + ... + + @abstractmethod + def get_connection_id(self, connection: Any) -> int: + """ + Get the current connection/backend process ID. + + Parameters + ---------- + connection : Any + Database connection object. + + Returns + ------- + int + Connection or process ID. + """ + ... + + @property + @abstractmethod + def default_port(self) -> int: + """ + Default port for this database backend. + + Returns + ------- + int + Default port number (3306 for MySQL, 5432 for PostgreSQL). + """ + ... + + @property + @abstractmethod + def backend(self) -> str: + """ + Backend identifier string. + + Returns + ------- + str + Backend name: 'mysql' or 'postgresql'. + """ + ... + + @abstractmethod + def get_cursor(self, connection: Any, as_dict: bool = False) -> Any: + """ + Get a cursor from the database connection. + + Parameters + ---------- + connection : Any + Database connection object. 
+ as_dict : bool, optional + If True, return cursor that yields rows as dictionaries. + If False, return cursor that yields rows as tuples. + Default False. + + Returns + ------- + Any + Database cursor object (backend-specific). + """ + ... + + # ========================================================================= + # SQL Syntax + # ========================================================================= + + @abstractmethod + def quote_identifier(self, name: str) -> str: + """ + Quote an identifier (table/column name) for this backend. + + Parameters + ---------- + name : str + Identifier to quote. + + Returns + ------- + str + Quoted identifier (e.g., `name` for MySQL, "name" for PostgreSQL). + """ + ... + + @abstractmethod + def quote_string(self, value: str) -> str: + """ + Quote a string literal for this backend. + + Parameters + ---------- + value : str + String value to quote. + + Returns + ------- + str + Quoted string literal with proper escaping. + """ + ... + + @abstractmethod + def get_master_table_name(self, part_table: str) -> str | None: + """ + Extract master table name from a part table name. + + Parameters + ---------- + part_table : str + Full table name (e.g., `schema`.`master__part` for MySQL, + "schema"."master__part" for PostgreSQL). + + Returns + ------- + str or None + Master table name if part_table is a part table, None otherwise. + """ + ... + + @property + @abstractmethod + def parameter_placeholder(self) -> str: + """ + Parameter placeholder style for this backend. + + Returns + ------- + str + Placeholder string (e.g., '%s' for MySQL/psycopg2, '?' for SQLite). + """ + ... + + # ========================================================================= + # Type Mapping + # ========================================================================= + + @abstractmethod + def core_type_to_sql(self, core_type: str) -> str: + """ + Convert a DataJoint core type to backend SQL type. + + Parameters + ---------- + core_type : str + DataJoint core type (e.g., 'int64', 'float32', 'uuid'). + + Returns + ------- + str + Backend SQL type (e.g., 'bigint', 'float', 'binary(16)'). + + Raises + ------ + ValueError + If core_type is not a valid DataJoint core type. + """ + ... + + @abstractmethod + def sql_type_to_core(self, sql_type: str) -> str | None: + """ + Convert a backend SQL type to DataJoint core type (if mappable). + + Parameters + ---------- + sql_type : str + Backend SQL type. + + Returns + ------- + str or None + DataJoint core type if mappable, None otherwise. + """ + ... + + # ========================================================================= + # DDL Generation + # ========================================================================= + + @abstractmethod + def create_schema_sql(self, schema_name: str) -> str: + """ + Generate CREATE SCHEMA/DATABASE statement. + + Parameters + ---------- + schema_name : str + Name of schema/database to create. + + Returns + ------- + str + CREATE SCHEMA/DATABASE SQL statement. + """ + ... + + @abstractmethod + def drop_schema_sql(self, schema_name: str, if_exists: bool = True) -> str: + """ + Generate DROP SCHEMA/DATABASE statement. + + Parameters + ---------- + schema_name : str + Name of schema/database to drop. + if_exists : bool, optional + Include IF EXISTS clause. Default True. + + Returns + ------- + str + DROP SCHEMA/DATABASE SQL statement. + """ + ... 
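As a concrete illustration of these two hooks, the MySQL adapter added later in this diff renders them as CREATE/DROP DATABASE with backtick quoting (the PostgreSQL equivalents live further down in postgres.py and are not shown in this excerpt; the schema name 'lab' is illustrative):

>>> from datajoint.adapters import get_adapter
>>> mysql = get_adapter('mysql')
>>> mysql.create_schema_sql('lab')
'CREATE DATABASE `lab`'
>>> mysql.drop_schema_sql('lab')
'DROP DATABASE IF EXISTS `lab`'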
+ + @abstractmethod + def create_table_sql( + self, + table_name: str, + columns: list[dict[str, Any]], + primary_key: list[str], + foreign_keys: list[dict[str, Any]], + indexes: list[dict[str, Any]], + comment: str | None = None, + ) -> str: + """ + Generate CREATE TABLE statement. + + Parameters + ---------- + table_name : str + Name of table to create. + columns : list[dict] + Column definitions with keys: name, type, nullable, default, comment. + primary_key : list[str] + List of primary key column names. + foreign_keys : list[dict] + Foreign key definitions with keys: columns, ref_table, ref_columns. + indexes : list[dict] + Index definitions with keys: columns, unique. + comment : str, optional + Table comment. + + Returns + ------- + str + CREATE TABLE SQL statement. + """ + ... + + @abstractmethod + def drop_table_sql(self, table_name: str, if_exists: bool = True) -> str: + """ + Generate DROP TABLE statement. + + Parameters + ---------- + table_name : str + Name of table to drop. + if_exists : bool, optional + Include IF EXISTS clause. Default True. + + Returns + ------- + str + DROP TABLE SQL statement. + """ + ... + + @abstractmethod + def alter_table_sql( + self, + table_name: str, + add_columns: list[dict[str, Any]] | None = None, + drop_columns: list[str] | None = None, + modify_columns: list[dict[str, Any]] | None = None, + ) -> str: + """ + Generate ALTER TABLE statement. + + Parameters + ---------- + table_name : str + Name of table to alter. + add_columns : list[dict], optional + Columns to add with keys: name, type, nullable, default, comment. + drop_columns : list[str], optional + Column names to drop. + modify_columns : list[dict], optional + Columns to modify with keys: name, type, nullable, default, comment. + + Returns + ------- + str + ALTER TABLE SQL statement. + """ + ... + + @abstractmethod + def add_comment_sql( + self, + object_type: str, + object_name: str, + comment: str, + ) -> str | None: + """ + Generate comment statement (may be None if embedded in CREATE). + + Parameters + ---------- + object_type : str + Type of object ('table', 'column'). + object_name : str + Fully qualified object name. + comment : str + Comment text. + + Returns + ------- + str or None + COMMENT statement, or None if comments are inline in CREATE. + """ + ... + + # ========================================================================= + # DML Generation + # ========================================================================= + + @abstractmethod + def insert_sql( + self, + table_name: str, + columns: list[str], + on_duplicate: str | None = None, + ) -> str: + """ + Generate INSERT statement. + + Parameters + ---------- + table_name : str + Name of table to insert into. + columns : list[str] + Column names to insert. + on_duplicate : str, optional + Duplicate handling: 'ignore', 'replace', 'update', or None. + + Returns + ------- + str + INSERT SQL statement with parameter placeholders. + """ + ... + + @abstractmethod + def update_sql( + self, + table_name: str, + set_columns: list[str], + where_columns: list[str], + ) -> str: + """ + Generate UPDATE statement. + + Parameters + ---------- + table_name : str + Name of table to update. + set_columns : list[str] + Column names to set. + where_columns : list[str] + Column names for WHERE clause. + + Returns + ------- + str + UPDATE SQL statement with parameter placeholders. + """ + ... + + @abstractmethod + def delete_sql(self, table_name: str) -> str: + """ + Generate DELETE statement (WHERE clause added separately). 
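For orientation, this is roughly what the MySQL implementation of insert_sql (defined later in this diff) emits for a hypothetical `lab`.`subject` table, using pymysql-style %s placeholders:

>>> from datajoint.adapters import get_adapter
>>> mysql = get_adapter('mysql')
>>> mysql.insert_sql('`lab`.`subject`', ['subject_id', 'species'])
'INSERT INTO `lab`.`subject` (`subject_id`, `species`) VALUES (%s, %s)'
>>> mysql.insert_sql('`lab`.`subject`', ['subject_id', 'species'], on_duplicate='ignore')
'INSERT IGNORE INTO `lab`.`subject` (`subject_id`, `species`) VALUES (%s, %s)'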
+ + Parameters + ---------- + table_name : str + Name of table to delete from. + + Returns + ------- + str + DELETE SQL statement without WHERE clause. + """ + ... + + @abstractmethod + def upsert_on_duplicate_sql( + self, + table_name: str, + columns: list[str], + primary_key: list[str], + num_rows: int, + ) -> str: + """ + Generate INSERT ... ON DUPLICATE KEY UPDATE (MySQL) or + INSERT ... ON CONFLICT ... DO UPDATE (PostgreSQL) statement. + + Parameters + ---------- + table_name : str + Fully qualified table name (with quotes). + columns : list[str] + Column names to insert (unquoted). + primary_key : list[str] + Primary key column names (unquoted) for conflict detection. + num_rows : int + Number of rows to insert (for generating placeholders). + + Returns + ------- + str + Upsert SQL statement with placeholders. + + Examples + -------- + MySQL: + INSERT INTO `table` (a, b, c) VALUES (%s, %s, %s), (%s, %s, %s) + ON DUPLICATE KEY UPDATE a = VALUES(a), b = VALUES(b), c = VALUES(c) + + PostgreSQL: + INSERT INTO "table" (a, b, c) VALUES (%s, %s, %s), (%s, %s, %s) + ON CONFLICT (a) DO UPDATE SET b = EXCLUDED.b, c = EXCLUDED.c + """ + ... + + @abstractmethod + def skip_duplicates_clause( + self, + full_table_name: str, + primary_key: list[str], + ) -> str: + """ + Generate clause to skip duplicate key insertions. + + For MySQL: ON DUPLICATE KEY UPDATE pk=table.pk (no-op update) + For PostgreSQL: ON CONFLICT (pk_cols) DO NOTHING + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). + primary_key : list[str] + Primary key column names (unquoted). + + Returns + ------- + str + SQL clause to append to INSERT statement. + """ + ... + + @property + def supports_inline_indexes(self) -> bool: + """ + Whether this backend supports inline INDEX in CREATE TABLE. + + MySQL supports inline index definitions in CREATE TABLE. + PostgreSQL requires separate CREATE INDEX statements. + + Returns + ------- + bool + True for MySQL, False for PostgreSQL. + """ + return True # Default for MySQL, override in PostgreSQL + + def create_index_ddl( + self, + full_table_name: str, + columns: list[str], + unique: bool = False, + index_name: str | None = None, + ) -> str: + """ + Generate CREATE INDEX statement. + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). + columns : list[str] + Column names to index (unquoted). + unique : bool, optional + If True, create a unique index. + index_name : str, optional + Custom index name. If None, auto-generate from table/columns. + + Returns + ------- + str + CREATE INDEX SQL statement. + """ + quoted_cols = ", ".join(self.quote_identifier(col) for col in columns) + # Generate index name from table and columns if not provided + if index_name is None: + # Extract table name from full_table_name for index naming + table_part = full_table_name.split(".")[-1].strip('`"') + col_part = "_".join(columns)[:30] # Truncate for long column lists + index_name = f"idx_{table_part}_{col_part}" + unique_clause = "UNIQUE " if unique else "" + return f"CREATE {unique_clause}INDEX {self.quote_identifier(index_name)} ON {full_table_name} ({quoted_cols})" + + # ========================================================================= + # Introspection + # ========================================================================= + + @abstractmethod + def list_schemas_sql(self) -> str: + """ + Generate query to list all schemas/databases. + + Returns + ------- + str + SQL query to list schemas. + """ + ... 
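The create_index_ddl helper above is concrete in the base class; a sketch of its output for a hypothetical `lab`.`session` table, showing the auto-derived index name:

>>> from datajoint.adapters.mysql import MySQLAdapter
>>> MySQLAdapter().create_index_ddl('`lab`.`session`', ['subject_id', 'session_date'])
'CREATE INDEX `idx_session_subject_id_session_date` ON `lab`.`session` (`subject_id`, `session_date`)'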
+ + @abstractmethod + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: + """ + Generate query to list tables in a schema. + + Parameters + ---------- + schema_name : str + Name of schema to list tables from. + pattern : str, optional + LIKE pattern to filter table names. Use %% for % in SQL. + + Returns + ------- + str + SQL query to list tables. + """ + ... + + @abstractmethod + def get_table_info_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get table metadata (comment, engine, etc.). + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get table info. + """ + ... + + @abstractmethod + def get_columns_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get column definitions. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get column definitions. + """ + ... + + @abstractmethod + def get_primary_key_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get primary key columns. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get primary key columns. + """ + ... + + @abstractmethod + def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get foreign key constraints. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get foreign key constraints. + """ + ... + + @abstractmethod + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: + """ + Generate query to get foreign key constraint details from information_schema. + + Used during cascade delete to determine FK columns when error message + doesn't provide full details. + + Parameters + ---------- + constraint_name : str + Name of the foreign key constraint. + schema_name : str + Schema/database name of the child table. + table_name : str + Name of the child table. + + Returns + ------- + str + SQL query that returns rows with columns: + - fk_attrs: foreign key column name in child table + - parent: parent table name (quoted, with schema) + - pk_attrs: referenced column name in parent table + """ + ... + + @abstractmethod + def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[str] | None] | None: + """ + Parse a foreign key violation error message to extract constraint details. + + Used during cascade delete to identify which child table is preventing + deletion and what columns are involved. + + Parameters + ---------- + error_message : str + The error message from a foreign key constraint violation. + + Returns + ------- + dict or None + Dictionary with keys if successfully parsed: + - child: child table name (quoted with schema if available) + - name: constraint name (quoted) + - fk_attrs: list of foreign key column names (may be None if not in message) + - parent: parent table name (quoted, may be None if not in message) + - pk_attrs: list of parent key column names (may be None if not in message) + + Returns None if error message doesn't match FK violation pattern. 
+ + Examples + -------- + MySQL error: + "Cannot delete or update a parent row: a foreign key constraint fails + (`schema`.`child`, CONSTRAINT `fk_name` FOREIGN KEY (`child_col`) + REFERENCES `parent` (`parent_col`))" + + PostgreSQL error: + "update or delete on table \"parent\" violates foreign key constraint + \"child_parent_id_fkey\" on table \"child\" + DETAIL: Key (parent_id)=(1) is still referenced from table \"child\"." + """ + ... + + @abstractmethod + def get_indexes_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get index definitions. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get index definitions. + """ + ... + + @abstractmethod + def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]: + """ + Parse a column info row into standardized format. + + Parameters + ---------- + row : dict + Raw column info row from database introspection query. + + Returns + ------- + dict + Standardized column info with keys: name, type, nullable, + default, comment, etc. + """ + ... + + # ========================================================================= + # Transactions + # ========================================================================= + + @abstractmethod + def start_transaction_sql(self, isolation_level: str | None = None) -> str: + """ + Generate START TRANSACTION statement. + + Parameters + ---------- + isolation_level : str, optional + Transaction isolation level. + + Returns + ------- + str + START TRANSACTION SQL statement. + """ + ... + + @abstractmethod + def commit_sql(self) -> str: + """ + Generate COMMIT statement. + + Returns + ------- + str + COMMIT SQL statement. + """ + ... + + @abstractmethod + def rollback_sql(self) -> str: + """ + Generate ROLLBACK statement. + + Returns + ------- + str + ROLLBACK SQL statement. + """ + ... + + # ========================================================================= + # Functions and Expressions + # ========================================================================= + + @abstractmethod + def current_timestamp_expr(self, precision: int | None = None) -> str: + """ + Expression for current timestamp. + + Parameters + ---------- + precision : int, optional + Fractional seconds precision (0-6). + + Returns + ------- + str + SQL expression for current timestamp. + """ + ... + + @abstractmethod + def interval_expr(self, value: int, unit: str) -> str: + """ + Expression for time interval. + + Parameters + ---------- + value : int + Interval value. + unit : str + Time unit ('second', 'minute', 'hour', 'day', etc.). + + Returns + ------- + str + SQL expression for interval (e.g., 'INTERVAL 5 SECOND' for MySQL, + "INTERVAL '5 seconds'" for PostgreSQL). + """ + ... + + @abstractmethod + def current_user_expr(self) -> str: + """ + SQL expression to get the current user. + + Returns + ------- + str + SQL expression for current user (e.g., 'user()' for MySQL, + 'current_user' for PostgreSQL). + """ + ... + + @abstractmethod + def json_path_expr(self, column: str, path: str, return_type: str | None = None) -> str: + """ + Generate JSON path extraction expression. + + Parameters + ---------- + column : str + Column name containing JSON data. + path : str + JSON path (e.g., 'field' or 'nested.field'). + return_type : str, optional + Return type specification (MySQL-specific). + + Returns + ------- + str + Database-specific JSON extraction SQL expression. 
+ + Examples + -------- + MySQL: json_value(`column`, _utf8mb4'$.path' returning type) + PostgreSQL: jsonb_extract_path_text("column", 'path_part1', 'path_part2') + """ + ... + + def translate_expression(self, expr: str) -> str: + """ + Translate SQL expression for backend compatibility. + + Converts database-specific function calls to the equivalent syntax + for the current backend. This enables portable DataJoint code that + uses common aggregate functions. + + Translations performed: + - GROUP_CONCAT(col) ↔ STRING_AGG(col, ',') + + Parameters + ---------- + expr : str + SQL expression that may contain function calls. + + Returns + ------- + str + Translated expression for the current backend. + + Notes + ----- + The base implementation returns the expression unchanged. + Subclasses override to provide backend-specific translations. + """ + return expr + + # ========================================================================= + # DDL Generation + # ========================================================================= + + @abstractmethod + def format_column_definition( + self, + name: str, + sql_type: str, + nullable: bool = False, + default: str | None = None, + comment: str | None = None, + ) -> str: + """ + Format a column definition for DDL. + + Parameters + ---------- + name : str + Column name. + sql_type : str + SQL type (already backend-specific, e.g., 'bigint', 'varchar(255)'). + nullable : bool, optional + Whether column is nullable. Default False. + default : str | None, optional + Default value expression (e.g., 'NULL', '"value"', 'CURRENT_TIMESTAMP'). + comment : str | None, optional + Column comment. + + Returns + ------- + str + Formatted column definition (without trailing comma). + + Examples + -------- + MySQL: `name` bigint NOT NULL COMMENT "user ID" + PostgreSQL: "name" bigint NOT NULL + """ + ... + + @abstractmethod + def table_options_clause(self, comment: str | None = None) -> str: + """ + Generate table options clause (ENGINE, etc.) for CREATE TABLE. + + Parameters + ---------- + comment : str | None, optional + Table-level comment. + + Returns + ------- + str + Table options clause (e.g., 'ENGINE=InnoDB, COMMENT "..."' for MySQL). + + Examples + -------- + MySQL: ENGINE=InnoDB, COMMENT "experiment sessions" + PostgreSQL: (empty string, comments handled separately) + """ + ... + + @abstractmethod + def table_comment_ddl(self, full_table_name: str, comment: str) -> str | None: + """ + Generate DDL for table-level comment (if separate from CREATE TABLE). + + Parameters + ---------- + full_table_name : str + Fully qualified table name (quoted). + comment : str + Table comment. + + Returns + ------- + str or None + DDL statement for table comment, or None if handled inline. + + Examples + -------- + MySQL: None (inline) + PostgreSQL: COMMENT ON TABLE "schema"."table" IS 'comment text' + """ + ... + + @abstractmethod + def column_comment_ddl(self, full_table_name: str, column_name: str, comment: str) -> str | None: + """ + Generate DDL for column-level comment (if separate from CREATE TABLE). + + Parameters + ---------- + full_table_name : str + Fully qualified table name (quoted). + column_name : str + Column name (unquoted). + comment : str + Column comment. + + Returns + ------- + str or None + DDL statement for column comment, or None if handled inline. + + Examples + -------- + MySQL: None (inline) + PostgreSQL: COMMENT ON COLUMN "schema"."table"."column" IS 'comment text' + """ + ... 
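To make the GROUP_CONCAT/STRING_AGG translation described in translate_expression above concrete, the MySQL override added later in this diff behaves roughly as follows (the column name is illustrative); the base-class implementation is an identity function and only the subclasses rewrite expressions:

>>> from datajoint.adapters.mysql import MySQLAdapter
>>> MySQLAdapter().translate_expression("STRING_AGG(name, ',')")
'GROUP_CONCAT(name)'
>>> MySQLAdapter().translate_expression("STRING_AGG(name, ';')")
"GROUP_CONCAT(name SEPARATOR ';')"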
+ + @abstractmethod + def enum_type_ddl(self, type_name: str, values: list[str]) -> str | None: + """ + Generate DDL for enum type definition (if needed before CREATE TABLE). + + Parameters + ---------- + type_name : str + Enum type name. + values : list[str] + Enum values. + + Returns + ------- + str or None + DDL statement for enum type, or None if handled inline. + + Examples + -------- + MySQL: None (inline enum('val1', 'val2')) + PostgreSQL: CREATE TYPE "type_name" AS ENUM ('val1', 'val2') + """ + ... + + @abstractmethod + def job_metadata_columns(self) -> list[str]: + """ + Return job metadata column definitions for Computed/Imported tables. + + Returns + ------- + list[str] + List of column definition strings (fully formatted with quotes). + + Examples + -------- + MySQL: + ["`_job_start_time` datetime(3) DEFAULT NULL", + "`_job_duration` float DEFAULT NULL", + "`_job_version` varchar(64) DEFAULT ''"] + PostgreSQL: + ['"_job_start_time" timestamp DEFAULT NULL', + '"_job_duration" real DEFAULT NULL', + '"_job_version" varchar(64) DEFAULT \'\''] + """ + ... + + # ========================================================================= + # Error Translation + # ========================================================================= + + @abstractmethod + def translate_error(self, error: Exception, query: str = "") -> Exception: + """ + Translate backend-specific error to DataJoint error. + + Parameters + ---------- + error : Exception + Backend-specific exception. + + Returns + ------- + Exception + DataJoint exception or original error if no mapping exists. + """ + ... + + # ========================================================================= + # Native Type Validation + # ========================================================================= + + @abstractmethod + def validate_native_type(self, type_str: str) -> bool: + """ + Check if a native type string is valid for this backend. + + Parameters + ---------- + type_str : str + Native type string to validate. + + Returns + ------- + bool + True if valid for this backend, False otherwise. + """ + ... diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py new file mode 100644 index 000000000..88339335f --- /dev/null +++ b/src/datajoint/adapters/mysql.py @@ -0,0 +1,1094 @@ +""" +MySQL database adapter for DataJoint. + +This module provides MySQL-specific implementations for SQL generation, +type mapping, error translation, and connection management. +""" + +from __future__ import annotations + +from typing import Any + +import pymysql as client + +from .. 
import errors +from .base import DatabaseAdapter + +# Core type mapping: DataJoint core types → MySQL types +CORE_TYPE_MAP = { + "int64": "bigint", + "int32": "int", + "int16": "smallint", + "int8": "tinyint", + "float32": "float", + "float64": "double", + "bool": "tinyint", + "uuid": "binary(16)", + "bytes": "longblob", + "json": "json", + "date": "date", + # datetime, char, varchar, decimal, enum require parameters - handled in method +} + +# Reverse mapping: MySQL types → DataJoint core types (for introspection) +SQL_TO_CORE_MAP = { + "bigint": "int64", + "int": "int32", + "smallint": "int16", + "tinyint": "int8", # Could be bool, need context + "float": "float32", + "double": "float64", + "binary(16)": "uuid", + "longblob": "bytes", + "json": "json", + "date": "date", +} + + +class MySQLAdapter(DatabaseAdapter): + """MySQL database adapter implementation.""" + + # ========================================================================= + # Connection Management + # ========================================================================= + + def connect( + self, + host: str, + port: int, + user: str, + password: str, + **kwargs: Any, + ) -> Any: + """ + Establish MySQL connection. + + Parameters + ---------- + host : str + MySQL server hostname. + port : int + MySQL server port. + user : str + Username for authentication. + password : str + Password for authentication. + **kwargs : Any + Additional MySQL-specific parameters: + - init_command: SQL initialization command + - ssl: TLS/SSL configuration dict (deprecated, use use_tls) + - use_tls: bool or dict - DataJoint's SSL parameter (preferred) + - charset: Character set (default from kwargs) + + Returns + ------- + pymysql.Connection + MySQL connection object. + """ + init_command = kwargs.get("init_command") + # Handle both ssl (old) and use_tls (new) parameter names + ssl_config = kwargs.get("use_tls", kwargs.get("ssl")) + # Convert boolean True to dict for PyMySQL (PyMySQL expects dict or SSLContext) + if ssl_config is True: + ssl_config = {} # Enable SSL with default settings + charset = kwargs.get("charset", "") + + # Prepare connection parameters + conn_params = { + "host": host, + "port": port, + "user": user, + "passwd": password, + "init_command": init_command, + "sql_mode": "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," + "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", + "charset": charset, + "autocommit": True, # DataJoint manages transactions explicitly + } + + # Handle SSL configuration + if ssl_config is False: + # Explicitly disable SSL + conn_params["ssl_disabled"] = True + elif ssl_config is not None: + # Enable SSL with config dict (can be empty for defaults) + conn_params["ssl"] = ssl_config + # Explicitly enable SSL by setting ssl_disabled=False + conn_params["ssl_disabled"] = False + + return client.connect(**conn_params) + + def close(self, connection: Any) -> None: + """Close the MySQL connection.""" + connection.close() + + def ping(self, connection: Any) -> bool: + """ + Check if MySQL connection is alive. + + Returns + ------- + bool + True if connection is alive. + """ + try: + connection.ping(reconnect=False) + return True + except Exception: + return False + + def get_connection_id(self, connection: Any) -> int: + """ + Get MySQL connection ID. + + Returns + ------- + int + MySQL connection_id(). 
+ """ + cursor = connection.cursor() + cursor.execute("SELECT connection_id()") + return cursor.fetchone()[0] + + @property + def default_port(self) -> int: + """MySQL default port 3306.""" + return 3306 + + @property + def backend(self) -> str: + """Backend identifier: 'mysql'.""" + return "mysql" + + def get_cursor(self, connection: Any, as_dict: bool = False) -> Any: + """ + Get a cursor from MySQL connection. + + Parameters + ---------- + connection : Any + pymysql connection object. + as_dict : bool, optional + If True, return DictCursor that yields rows as dictionaries. + If False, return standard Cursor that yields rows as tuples. + Default False. + + Returns + ------- + Any + pymysql cursor object. + """ + import pymysql + + cursor_class = pymysql.cursors.DictCursor if as_dict else pymysql.cursors.Cursor + return connection.cursor(cursor=cursor_class) + + # ========================================================================= + # SQL Syntax + # ========================================================================= + + def quote_identifier(self, name: str) -> str: + """ + Quote identifier with backticks for MySQL. + + Parameters + ---------- + name : str + Identifier to quote. + + Returns + ------- + str + Backtick-quoted identifier: `name` + """ + return f"`{name}`" + + def quote_string(self, value: str) -> str: + """ + Quote string literal for MySQL with escaping. + + Parameters + ---------- + value : str + String value to quote. + + Returns + ------- + str + Quoted and escaped string literal. + """ + # Use pymysql's escape_string for proper escaping + escaped = client.converters.escape_string(value) + return f"'{escaped}'" + + def get_master_table_name(self, part_table: str) -> str | None: + """Extract master table name from part table (MySQL backtick format).""" + import re + + # MySQL format: `schema`.`master__part` + match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", part_table) + return match["master"] + "`" if match else None + + @property + def parameter_placeholder(self) -> str: + """MySQL/pymysql uses %s placeholders.""" + return "%s" + + # ========================================================================= + # Type Mapping + # ========================================================================= + + def core_type_to_sql(self, core_type: str) -> str: + """ + Convert DataJoint core type to MySQL type. + + Parameters + ---------- + core_type : str + DataJoint core type, possibly with parameters: + - int64, float32, bool, uuid, bytes, json, date + - datetime or datetime(n) + - char(n), varchar(n) + - decimal(p,s) + - enum('a','b','c') + + Returns + ------- + str + MySQL SQL type. + + Raises + ------ + ValueError + If core_type is not recognized. + """ + # Handle simple types without parameters + if core_type in CORE_TYPE_MAP: + return CORE_TYPE_MAP[core_type] + + # Handle parametrized types + if core_type.startswith("datetime"): + # datetime or datetime(precision) + return core_type # MySQL supports datetime(n) directly + + if core_type.startswith("char("): + # char(n) + return core_type + + if core_type.startswith("varchar("): + # varchar(n) + return core_type + + if core_type.startswith("decimal("): + # decimal(precision, scale) + return core_type + + if core_type.startswith("enum("): + # enum('value1', 'value2', ...) + return core_type + + raise ValueError(f"Unknown core type: {core_type}") + + def sql_type_to_core(self, sql_type: str) -> str | None: + """ + Convert MySQL type to DataJoint core type (if mappable). 
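A few illustrative mappings from the MySQL type and quoting layer just defined (the attribute name 'subject_id' is hypothetical):

>>> from datajoint.adapters.mysql import MySQLAdapter
>>> adapter = MySQLAdapter()
>>> adapter.quote_identifier('subject_id')
'`subject_id`'
>>> adapter.core_type_to_sql('int64'), adapter.core_type_to_sql('uuid')
('bigint', 'binary(16)')
>>> adapter.core_type_to_sql('varchar(255)')  # parametrized types pass through unchanged
'varchar(255)'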
+ + Parameters + ---------- + sql_type : str + MySQL SQL type. + + Returns + ------- + str or None + DataJoint core type if mappable, None otherwise. + """ + # Normalize type string (lowercase, strip spaces) + sql_type_lower = sql_type.lower().strip() + + # Direct mapping + if sql_type_lower in SQL_TO_CORE_MAP: + return SQL_TO_CORE_MAP[sql_type_lower] + + # Handle parametrized types + if sql_type_lower.startswith("datetime"): + return sql_type # Keep precision + + if sql_type_lower.startswith("char("): + return sql_type # Keep size + + if sql_type_lower.startswith("varchar("): + return sql_type # Keep size + + if sql_type_lower.startswith("decimal("): + return sql_type # Keep precision/scale + + if sql_type_lower.startswith("enum("): + return sql_type # Keep values + + # Not a mappable core type + return None + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def create_schema_sql(self, schema_name: str) -> str: + """ + Generate CREATE DATABASE statement for MySQL. + + Parameters + ---------- + schema_name : str + Database name. + + Returns + ------- + str + CREATE DATABASE SQL. + """ + return f"CREATE DATABASE {self.quote_identifier(schema_name)}" + + def drop_schema_sql(self, schema_name: str, if_exists: bool = True) -> str: + """ + Generate DROP DATABASE statement for MySQL. + + Parameters + ---------- + schema_name : str + Database name. + if_exists : bool + Include IF EXISTS clause. + + Returns + ------- + str + DROP DATABASE SQL. + """ + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP DATABASE {if_exists_clause}{self.quote_identifier(schema_name)}" + + def create_table_sql( + self, + table_name: str, + columns: list[dict[str, Any]], + primary_key: list[str], + foreign_keys: list[dict[str, Any]], + indexes: list[dict[str, Any]], + comment: str | None = None, + ) -> str: + """ + Generate CREATE TABLE statement for MySQL. + + Parameters + ---------- + table_name : str + Fully qualified table name (schema.table). + columns : list[dict] + Column defs: [{name, type, nullable, default, comment}, ...] + primary_key : list[str] + Primary key column names. + foreign_keys : list[dict] + FK defs: [{columns, ref_table, ref_columns}, ...] + indexes : list[dict] + Index defs: [{columns, unique}, ...] + comment : str, optional + Table comment. + + Returns + ------- + str + CREATE TABLE SQL statement. 
+ """ + lines = [] + + # Column definitions + for col in columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + default = f" DEFAULT {col['default']}" if "default" in col else "" + col_comment = f" COMMENT {self.quote_string(col['comment'])}" if "comment" in col else "" + lines.append(f"{col_name} {col_type} {nullable}{default}{col_comment}") + + # Primary key + if primary_key: + pk_cols = ", ".join(self.quote_identifier(col) for col in primary_key) + lines.append(f"PRIMARY KEY ({pk_cols})") + + # Foreign keys + for fk in foreign_keys: + fk_cols = ", ".join(self.quote_identifier(col) for col in fk["columns"]) + ref_cols = ", ".join(self.quote_identifier(col) for col in fk["ref_columns"]) + lines.append( + f"FOREIGN KEY ({fk_cols}) REFERENCES {fk['ref_table']} ({ref_cols}) " f"ON UPDATE CASCADE ON DELETE RESTRICT" + ) + + # Indexes + for idx in indexes: + unique = "UNIQUE " if idx.get("unique", False) else "" + idx_cols = ", ".join(self.quote_identifier(col) for col in idx["columns"]) + lines.append(f"{unique}INDEX ({idx_cols})") + + # Assemble CREATE TABLE + table_def = ",\n ".join(lines) + comment_clause = f" COMMENT={self.quote_string(comment)}" if comment else "" + return f"CREATE TABLE IF NOT EXISTS {table_name} (\n {table_def}\n) ENGINE=InnoDB{comment_clause}" + + def drop_table_sql(self, table_name: str, if_exists: bool = True) -> str: + """Generate DROP TABLE statement for MySQL.""" + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP TABLE {if_exists_clause}{table_name}" + + def alter_table_sql( + self, + table_name: str, + add_columns: list[dict[str, Any]] | None = None, + drop_columns: list[str] | None = None, + modify_columns: list[dict[str, Any]] | None = None, + ) -> str: + """ + Generate ALTER TABLE statement for MySQL. + + Parameters + ---------- + table_name : str + Table name. + add_columns : list[dict], optional + Columns to add. + drop_columns : list[str], optional + Column names to drop. + modify_columns : list[dict], optional + Columns to modify. + + Returns + ------- + str + ALTER TABLE SQL statement. + """ + clauses = [] + + if add_columns: + for col in add_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + clauses.append(f"ADD {col_name} {col_type} {nullable}") + + if drop_columns: + for col_name in drop_columns: + clauses.append(f"DROP {self.quote_identifier(col_name)}") + + if modify_columns: + for col in modify_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + clauses.append(f"MODIFY {col_name} {col_type} {nullable}") + + return f"ALTER TABLE {table_name} {', '.join(clauses)}" + + def add_comment_sql( + self, + object_type: str, + object_name: str, + comment: str, + ) -> str | None: + """ + MySQL embeds comments in CREATE/ALTER, not separate statements. + + Returns None since comments are inline. + """ + return None + + # ========================================================================= + # DML Generation + # ========================================================================= + + def insert_sql( + self, + table_name: str, + columns: list[str], + on_duplicate: str | None = None, + ) -> str: + """ + Generate INSERT statement for MySQL. + + Parameters + ---------- + table_name : str + Table name. + columns : list[str] + Column names. 
+ on_duplicate : str, optional + 'ignore', 'replace', or 'update'. + + Returns + ------- + str + INSERT SQL with placeholders. + """ + cols = ", ".join(self.quote_identifier(col) for col in columns) + placeholders = ", ".join([self.parameter_placeholder] * len(columns)) + + if on_duplicate == "ignore": + return f"INSERT IGNORE INTO {table_name} ({cols}) VALUES ({placeholders})" + elif on_duplicate == "replace": + return f"REPLACE INTO {table_name} ({cols}) VALUES ({placeholders})" + elif on_duplicate == "update": + # ON DUPLICATE KEY UPDATE col=VALUES(col) + updates = ", ".join(f"{self.quote_identifier(col)}=VALUES({self.quote_identifier(col)})" for col in columns) + return f"INSERT INTO {table_name} ({cols}) VALUES ({placeholders}) ON DUPLICATE KEY UPDATE {updates}" + else: + return f"INSERT INTO {table_name} ({cols}) VALUES ({placeholders})" + + def update_sql( + self, + table_name: str, + set_columns: list[str], + where_columns: list[str], + ) -> str: + """Generate UPDATE statement for MySQL.""" + set_clause = ", ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in set_columns) + where_clause = " AND ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in where_columns) + return f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}" + + def delete_sql(self, table_name: str) -> str: + """Generate DELETE statement for MySQL (WHERE added separately).""" + return f"DELETE FROM {table_name}" + + def upsert_on_duplicate_sql( + self, + table_name: str, + columns: list[str], + primary_key: list[str], + num_rows: int, + ) -> str: + """Generate INSERT ... ON DUPLICATE KEY UPDATE statement for MySQL.""" + # Build column list + col_list = ", ".join(columns) + + # Build placeholders for VALUES + placeholders = ", ".join(["(%s)" % ", ".join(["%s"] * len(columns))] * num_rows) + + # Build UPDATE clause (all columns) + update_clauses = ", ".join(f"{col} = VALUES({col})" for col in columns) + + return f""" + INSERT INTO {table_name} ({col_list}) + VALUES {placeholders} + ON DUPLICATE KEY UPDATE {update_clauses} + """ + + def skip_duplicates_clause( + self, + full_table_name: str, + primary_key: list[str], + ) -> str: + """ + Generate clause to skip duplicate key insertions for MySQL. + + Uses ON DUPLICATE KEY UPDATE with a no-op update (pk=pk) to effectively + skip duplicates without raising an error. + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). + primary_key : list[str] + Primary key column names (unquoted). + + Returns + ------- + str + MySQL ON DUPLICATE KEY UPDATE clause. 
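Examples
--------
A sketch for a hypothetical `lab`.`session` table keyed by subject_id; note the leading space, so the clause can be appended directly to an INSERT statement:

>>> adapter.skip_duplicates_clause('`lab`.`session`', ['subject_id'])
' ON DUPLICATE KEY UPDATE `subject_id`=`lab`.`session`.`subject_id`'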
+ """ + quoted_pk = self.quote_identifier(primary_key[0]) + return f" ON DUPLICATE KEY UPDATE {quoted_pk}={full_table_name}.{quoted_pk}" + + # ========================================================================= + # Introspection + # ========================================================================= + + def list_schemas_sql(self) -> str: + """Query to list all databases in MySQL.""" + return "SELECT schema_name FROM information_schema.schemata" + + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: + """Query to list tables in a database.""" + sql = f"SHOW TABLES IN {self.quote_identifier(schema_name)}" + if pattern: + sql += f" LIKE '{pattern}'" + return sql + + def get_table_info_sql(self, schema_name: str, table_name: str) -> str: + """Query to get table metadata (comment, engine, etc.).""" + return ( + f"SELECT * FROM information_schema.tables " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)}" + ) + + def get_columns_sql(self, schema_name: str, table_name: str) -> str: + """Query to get column definitions.""" + return f"SHOW FULL COLUMNS FROM {self.quote_identifier(table_name)} IN {self.quote_identifier(schema_name)}" + + def get_primary_key_sql(self, schema_name: str, table_name: str) -> str: + """Query to get primary key columns.""" + return ( + f"SELECT COLUMN_NAME as column_name FROM information_schema.key_column_usage " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND constraint_name = 'PRIMARY' " + f"ORDER BY ordinal_position" + ) + + def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: + """Query to get foreign key constraints.""" + return ( + f"SELECT CONSTRAINT_NAME as constraint_name, COLUMN_NAME as column_name, " + f"REFERENCED_TABLE_NAME as referenced_table_name, REFERENCED_COLUMN_NAME as referenced_column_name " + f"FROM information_schema.key_column_usage " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND referenced_table_name IS NOT NULL " + f"ORDER BY constraint_name, ordinal_position" + ) + + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: + """Query to get FK constraint details from information_schema.""" + return ( + "SELECT " + " COLUMN_NAME as fk_attrs, " + " CONCAT('`', REFERENCED_TABLE_SCHEMA, '`.`', REFERENCED_TABLE_NAME, '`') as parent, " + " REFERENCED_COLUMN_NAME as pk_attrs " + "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE " + "WHERE CONSTRAINT_NAME = %s AND TABLE_SCHEMA = %s AND TABLE_NAME = %s" + ) + + def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[str] | None] | None: + """Parse MySQL foreign key violation error message.""" + import re + + # MySQL FK error pattern with backticks + pattern = re.compile( + r"[\w\s:]*\((?P`[^`]+`.`[^`]+`), " + r"CONSTRAINT (?P`[^`]+`) " + r"(FOREIGN KEY \((?P[^)]+)\) " + r"REFERENCES (?P`[^`]+`(\.`[^`]+`)?) \((?P[^)]+)\)[\s\w]+\))?" 
+ ) + + match = pattern.match(error_message) + if not match: + return None + + result = match.groupdict() + + # Parse comma-separated FK attrs if present + if result.get("fk_attrs"): + result["fk_attrs"] = [col.strip("`") for col in result["fk_attrs"].split(",")] + # Parse comma-separated PK attrs if present + if result.get("pk_attrs"): + result["pk_attrs"] = [col.strip("`") for col in result["pk_attrs"].split(",")] + + return result + + def get_indexes_sql(self, schema_name: str, table_name: str) -> str: + """Query to get index definitions. + + Note: For MySQL 8.0+, EXPRESSION column contains the expression for + functional indexes. COLUMN_NAME is NULL for such indexes. + """ + return ( + f"SELECT INDEX_NAME as index_name, " + f"COALESCE(COLUMN_NAME, CONCAT('(', EXPRESSION, ')')) as column_name, " + f"NON_UNIQUE as non_unique, SEQ_IN_INDEX as seq_in_index " + f"FROM information_schema.statistics " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND index_name != 'PRIMARY' " + f"ORDER BY index_name, seq_in_index" + ) + + def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]: + """ + Parse MySQL SHOW FULL COLUMNS output into standardized format. + + Parameters + ---------- + row : dict + Row from SHOW FULL COLUMNS query. + + Returns + ------- + dict + Standardized column info with keys: + name, type, nullable, default, comment, key, extra + """ + return { + "name": row["Field"], + "type": row["Type"], + "nullable": row["Null"] == "YES", + "default": row["Default"], + "comment": row["Comment"], + "key": row["Key"], # PRI, UNI, MUL + "extra": row["Extra"], # auto_increment, etc. + } + + # ========================================================================= + # Transactions + # ========================================================================= + + def start_transaction_sql(self, isolation_level: str | None = None) -> str: + """Generate START TRANSACTION statement.""" + if isolation_level: + return f"START TRANSACTION WITH CONSISTENT SNAPSHOT, {isolation_level}" + return "START TRANSACTION WITH CONSISTENT SNAPSHOT" + + def commit_sql(self) -> str: + """Generate COMMIT statement.""" + return "COMMIT" + + def rollback_sql(self) -> str: + """Generate ROLLBACK statement.""" + return "ROLLBACK" + + # ========================================================================= + # Functions and Expressions + # ========================================================================= + + def current_timestamp_expr(self, precision: int | None = None) -> str: + """ + CURRENT_TIMESTAMP expression for MySQL. + + Parameters + ---------- + precision : int, optional + Fractional seconds precision (0-6). + + Returns + ------- + str + CURRENT_TIMESTAMP or CURRENT_TIMESTAMP(n). + """ + if precision is not None: + return f"CURRENT_TIMESTAMP({precision})" + return "CURRENT_TIMESTAMP" + + def interval_expr(self, value: int, unit: str) -> str: + """ + INTERVAL expression for MySQL. + + Parameters + ---------- + value : int + Interval value. + unit : str + Time unit (singular: 'second', 'minute', 'hour', 'day'). + + Returns + ------- + str + INTERVAL n UNIT (e.g., 'INTERVAL 5 SECOND'). + """ + # MySQL uses singular unit names + return f"INTERVAL {value} {unit.upper()}" + + def current_user_expr(self) -> str: + """MySQL current user expression.""" + return "user()" + + def json_path_expr(self, column: str, path: str, return_type: str | None = None) -> str: + """ + Generate MySQL json_value() expression. 
+ + Parameters + ---------- + column : str + Column name containing JSON data. + path : str + JSON path (e.g., 'field' or 'nested.field'). + return_type : str, optional + Return type specification (e.g., 'decimal(10,2)'). + + Returns + ------- + str + MySQL json_value() expression. + + Examples + -------- + >>> adapter.json_path_expr('data', 'field') + "json_value(`data`, _utf8mb4'$.field')" + >>> adapter.json_path_expr('data', 'value', 'decimal(10,2)') + "json_value(`data`, _utf8mb4'$.value' returning decimal(10,2))" + """ + quoted_col = self.quote_identifier(column) + return_clause = f" returning {return_type}" if return_type else "" + return f"json_value({quoted_col}, _utf8mb4'$.{path}'{return_clause})" + + def translate_expression(self, expr: str) -> str: + """ + Translate SQL expression for MySQL compatibility. + + Converts PostgreSQL-specific functions to MySQL equivalents: + - STRING_AGG(col, 'sep') → GROUP_CONCAT(col SEPARATOR 'sep') + - STRING_AGG(col, ',') → GROUP_CONCAT(col) + + Parameters + ---------- + expr : str + SQL expression that may contain function calls. + + Returns + ------- + str + Translated expression for MySQL. + """ + import re + + # STRING_AGG(col, 'sep') → GROUP_CONCAT(col SEPARATOR 'sep') + def replace_string_agg(match): + inner = match.group(1).strip() + # Parse arguments: col, 'separator' + # Handle both single and double quoted separators + arg_match = re.match(r"(.+?)\s*,\s*(['\"])(.+?)\2", inner) + if arg_match: + col = arg_match.group(1).strip() + sep = arg_match.group(3) + # Remove ::text cast if present (PostgreSQL-specific) + col = re.sub(r"::text$", "", col) + if sep == ",": + return f"GROUP_CONCAT({col})" + else: + return f"GROUP_CONCAT({col} SEPARATOR '{sep}')" + else: + # No separator found, just use the expression + col = re.sub(r"::text$", "", inner) + return f"GROUP_CONCAT({col})" + + expr = re.sub(r"STRING_AGG\s*\((.+?)\)", replace_string_agg, expr, flags=re.IGNORECASE) + + return expr + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def format_column_definition( + self, + name: str, + sql_type: str, + nullable: bool = False, + default: str | None = None, + comment: str | None = None, + ) -> str: + """ + Format a column definition for MySQL DDL. + + Examples + -------- + >>> adapter.format_column_definition('user_id', 'bigint', nullable=False, comment='user ID') + "`user_id` bigint NOT NULL COMMENT \\"user ID\\"" + """ + parts = [self.quote_identifier(name), sql_type] + if default: + parts.append(default) # e.g., "DEFAULT NULL" or "NOT NULL DEFAULT 5" + elif not nullable: + parts.append("NOT NULL") + if comment: + parts.append(f'COMMENT "{comment}"') + return " ".join(parts) + + def table_options_clause(self, comment: str | None = None) -> str: + """ + Generate MySQL table options clause. + + Examples + -------- + >>> adapter.table_options_clause('test table') + 'ENGINE=InnoDB, COMMENT "test table"' + >>> adapter.table_options_clause() + 'ENGINE=InnoDB' + """ + clause = "ENGINE=InnoDB" + if comment: + clause += f', COMMENT "{comment}"' + return clause + + def table_comment_ddl(self, full_table_name: str, comment: str) -> str | None: + """ + MySQL uses inline COMMENT in CREATE TABLE, so no separate DDL needed. 
+ + Examples + -------- + >>> adapter.table_comment_ddl('`schema`.`table`', 'test comment') + None + """ + return None # MySQL uses inline COMMENT + + def column_comment_ddl(self, full_table_name: str, column_name: str, comment: str) -> str | None: + """ + MySQL uses inline COMMENT in column definitions, so no separate DDL needed. + + Examples + -------- + >>> adapter.column_comment_ddl('`schema`.`table`', 'column', 'test comment') + None + """ + return None # MySQL uses inline COMMENT + + def enum_type_ddl(self, type_name: str, values: list[str]) -> str | None: + """ + MySQL uses inline enum type in column definition, so no separate DDL needed. + + Examples + -------- + >>> adapter.enum_type_ddl('status_type', ['active', 'inactive']) + None + """ + return None # MySQL uses inline enum + + def job_metadata_columns(self) -> list[str]: + """ + Return MySQL-specific job metadata column definitions. + + Examples + -------- + >>> adapter.job_metadata_columns() + ["`_job_start_time` datetime(3) DEFAULT NULL", + "`_job_duration` float DEFAULT NULL", + "`_job_version` varchar(64) DEFAULT ''"] + """ + return [ + "`_job_start_time` datetime(3) DEFAULT NULL", + "`_job_duration` float DEFAULT NULL", + "`_job_version` varchar(64) DEFAULT ''", + ] + + # ========================================================================= + # Error Translation + # ========================================================================= + + def translate_error(self, error: Exception, query: str = "") -> Exception: + """ + Translate MySQL error to DataJoint exception. + + Parameters + ---------- + error : Exception + MySQL exception (typically pymysql error). + + Returns + ------- + Exception + DataJoint exception or original error. + """ + if not hasattr(error, "args") or len(error.args) == 0: + return error + + err, *args = error.args + + match err: + # Loss of connection errors + case 0 | "(0, '')": + return errors.LostConnectionError("Server connection lost due to an interface error.", *args) + case 2006: + return errors.LostConnectionError("Connection timed out", *args) + case 2013: + return errors.LostConnectionError("Server connection lost", *args) + + # Access errors + case 1044 | 1142: + query = args[0] if args else "" + return errors.AccessError("Insufficient privileges.", args[0] if args else "", query) + + # Integrity errors + case 1062: + return errors.DuplicateError(*args) + case 1217 | 1451 | 1452 | 3730: + return errors.IntegrityError(*args) + + # Syntax errors + case 1064: + query = args[0] if args else "" + return errors.QuerySyntaxError(args[0] if args else "", query) + + # Existence errors + case 1146: + query = args[0] if args else "" + return errors.MissingTableError(args[0] if args else "", query) + case 1364: + return errors.MissingAttributeError(*args) + case 1054: + return errors.UnknownAttributeError(*args) + + # All other errors pass through unchanged + case _: + return error + + # ========================================================================= + # Native Type Validation + # ========================================================================= + + def validate_native_type(self, type_str: str) -> bool: + """ + Check if a native MySQL type string is valid. + + Parameters + ---------- + type_str : str + Type string to validate. + + Returns + ------- + bool + True if valid MySQL type. 
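Examples
--------
A rough sketch of the accept/reject behaviour; only the base type before any parameters is checked, and matching is case-insensitive:

>>> adapter.validate_native_type('MEDIUMINT')
True
>>> adapter.validate_native_type('varchar(255)')
True
>>> adapter.validate_native_type('serial')
False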
+ """ + type_lower = type_str.lower().strip() + + # MySQL native types (simplified validation) + valid_types = { + # Integer types + "tinyint", + "smallint", + "mediumint", + "int", + "integer", + "bigint", + # Floating point + "float", + "double", + "real", + "decimal", + "numeric", + # String types + "char", + "varchar", + "binary", + "varbinary", + "tinyblob", + "blob", + "mediumblob", + "longblob", + "tinytext", + "text", + "mediumtext", + "longtext", + # Temporal types + "date", + "time", + "datetime", + "timestamp", + "year", + # Other + "enum", + "set", + "json", + "geometry", + } + + # Extract base type (before parentheses) + base_type = type_lower.split("(")[0].strip() + + return base_type in valid_types diff --git a/src/datajoint/adapters/postgres.py b/src/datajoint/adapters/postgres.py new file mode 100644 index 000000000..12fecae6a --- /dev/null +++ b/src/datajoint/adapters/postgres.py @@ -0,0 +1,1510 @@ +""" +PostgreSQL database adapter for DataJoint. + +This module provides PostgreSQL-specific implementations for SQL generation, +type mapping, error translation, and connection management. +""" + +from __future__ import annotations + +import re +from typing import Any + +try: + import psycopg2 as client + from psycopg2 import sql +except ImportError: + client = None # type: ignore + sql = None # type: ignore + +from .. import errors +from .base import DatabaseAdapter + +# Core type mapping: DataJoint core types → PostgreSQL types +CORE_TYPE_MAP = { + "int64": "bigint", + "int32": "integer", + "int16": "smallint", + "int8": "smallint", # PostgreSQL lacks tinyint; semantically equivalent + "float32": "real", + "float64": "double precision", + "bool": "boolean", + "uuid": "uuid", # Native UUID support + "bytes": "bytea", + "json": "jsonb", # Using jsonb for better performance + "date": "date", + # datetime, char, varchar, decimal, enum require parameters - handled in method +} + +# Reverse mapping: PostgreSQL types → DataJoint core types (for introspection) +SQL_TO_CORE_MAP = { + "bigint": "int64", + "integer": "int32", + "smallint": "int16", + "real": "float32", + "double precision": "float64", + "boolean": "bool", + "uuid": "uuid", + "bytea": "bytes", + "jsonb": "json", + "json": "json", + "date": "date", +} + + +class PostgreSQLAdapter(DatabaseAdapter): + """PostgreSQL database adapter implementation.""" + + def __init__(self) -> None: + """Initialize PostgreSQL adapter.""" + if client is None: + raise ImportError( + "psycopg2 is required for PostgreSQL support. " "Install it with: pip install 'datajoint[postgres]'" + ) + + # ========================================================================= + # Connection Management + # ========================================================================= + + def connect( + self, + host: str, + port: int, + user: str, + password: str, + **kwargs: Any, + ) -> Any: + """ + Establish PostgreSQL connection. + + Parameters + ---------- + host : str + PostgreSQL server hostname. + port : int + PostgreSQL server port. + user : str + Username for authentication. + password : str + Password for authentication. + **kwargs : Any + Additional PostgreSQL-specific parameters: + - dbname: Database name + - sslmode: SSL mode ('disable', 'allow', 'prefer', 'require') + - use_tls: bool or dict - DataJoint's SSL parameter (converted to sslmode) + - connect_timeout: Connection timeout in seconds + + Returns + ------- + psycopg2.connection + PostgreSQL connection object. 
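+
+        Notes
+        -----
+        The returned connection has ``autocommit`` enabled because DataJoint
+        issues BEGIN/COMMIT/ROLLBACK explicitly through this adapter's
+        transaction methods. psycopg2 adapters for numpy scalar types are
+        registered on connect so numpy values can be passed as query
+        parameters directly.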
+ """ + dbname = kwargs.get("dbname", "postgres") # Default to postgres database + connect_timeout = kwargs.get("connect_timeout", 10) + + # Handle use_tls parameter (from DataJoint Connection) + # Convert to PostgreSQL's sslmode + use_tls = kwargs.get("use_tls") + if "sslmode" in kwargs: + # Explicit sslmode takes precedence + sslmode = kwargs["sslmode"] + elif use_tls is False: + # use_tls=False → disable SSL + sslmode = "disable" + elif use_tls is True or isinstance(use_tls, dict): + # use_tls=True or dict → require SSL + sslmode = "require" + else: + # use_tls=None (default) → prefer SSL but allow fallback + sslmode = "prefer" + + conn = client.connect( + host=host, + port=port, + user=user, + password=password, + dbname=dbname, + sslmode=sslmode, + connect_timeout=connect_timeout, + ) + # DataJoint manages transactions explicitly via start_transaction() + # Set autocommit=True to avoid implicit transactions + conn.autocommit = True + + # Register numpy type adapters so numpy types can be used directly in queries + self._register_numpy_adapters() + + return conn + + def _register_numpy_adapters(self) -> None: + """ + Register psycopg2 adapters for numpy types. + + This allows numpy scalar types (bool_, int64, float64, etc.) to be used + directly in queries without explicit conversion to Python native types. + """ + try: + import numpy as np + from psycopg2.extensions import register_adapter, AsIs + + # Numpy bool type + register_adapter(np.bool_, lambda x: AsIs(str(bool(x)).upper())) + + # Numpy integer types + for np_type in (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64): + register_adapter(np_type, lambda x: AsIs(int(x))) + + # Numpy float types + for np_ftype in (np.float16, np.float32, np.float64): + register_adapter(np_ftype, lambda x: AsIs(repr(float(x)))) + + except ImportError: + pass # numpy not available + + def close(self, connection: Any) -> None: + """Close the PostgreSQL connection.""" + connection.close() + + def ping(self, connection: Any) -> bool: + """ + Check if PostgreSQL connection is alive. + + Returns + ------- + bool + True if connection is alive. + """ + try: + cursor = connection.cursor() + cursor.execute("SELECT 1") + cursor.close() + return True + except Exception: + return False + + def get_connection_id(self, connection: Any) -> int: + """ + Get PostgreSQL backend process ID. + + Returns + ------- + int + PostgreSQL pg_backend_pid(). + """ + cursor = connection.cursor() + cursor.execute("SELECT pg_backend_pid()") + return cursor.fetchone()[0] + + @property + def default_port(self) -> int: + """PostgreSQL default port 5432.""" + return 5432 + + @property + def backend(self) -> str: + """Backend identifier: 'postgresql'.""" + return "postgresql" + + def get_cursor(self, connection: Any, as_dict: bool = False) -> Any: + """ + Get a cursor from PostgreSQL connection. + + Parameters + ---------- + connection : Any + psycopg2 connection object. + as_dict : bool, optional + If True, return Real DictCursor that yields rows as dictionaries. + If False, return standard cursor that yields rows as tuples. + Default False. + + Returns + ------- + Any + psycopg2 cursor object. 
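+
+        Notes
+        -----
+        With ``as_dict=True`` the cursor is created with
+        ``psycopg2.extras.RealDictCursor``, so fetched rows behave like
+        dictionaries keyed by column name.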
+        """
+        import psycopg2.extras
+
+        if as_dict:
+            return connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
+        return connection.cursor()
+
+    # =========================================================================
+    # SQL Syntax
+    # =========================================================================
+
+    def quote_identifier(self, name: str) -> str:
+        """
+        Quote identifier with double quotes for PostgreSQL.
+
+        Parameters
+        ----------
+        name : str
+            Identifier to quote.
+
+        Returns
+        -------
+        str
+            Double-quoted identifier: "name"
+        """
+        return f'"{name}"'
+
+    def quote_string(self, value: str) -> str:
+        """
+        Quote string literal for PostgreSQL with escaping.
+
+        Parameters
+        ----------
+        value : str
+            String value to quote.
+
+        Returns
+        -------
+        str
+            Quoted and escaped string literal.
+        """
+        # Escape single quotes by doubling them (PostgreSQL standard)
+        escaped = value.replace("'", "''")
+        return f"'{escaped}'"
+
+    def get_master_table_name(self, part_table: str) -> str | None:
+        """Extract master table name from part table (PostgreSQL double-quote format)."""
+        import re
+
+        # PostgreSQL format: "schema"."master__part"
+        match = re.match(r'(?P<master>"\w+"\."#?\w+)__\w+"', part_table)
+        return match["master"] + '"' if match else None
+
+    @property
+    def parameter_placeholder(self) -> str:
+        """PostgreSQL/psycopg2 uses %s placeholders."""
+        return "%s"
+
+    # =========================================================================
+    # Type Mapping
+    # =========================================================================
+
+    def core_type_to_sql(self, core_type: str) -> str:
+        """
+        Convert DataJoint core type to PostgreSQL type.
+
+        Parameters
+        ----------
+        core_type : str
+            DataJoint core type, possibly with parameters:
+            - int64, float32, bool, uuid, bytes, json, date
+            - datetime or datetime(n) → timestamp(n)
+            - char(n), varchar(n)
+            - decimal(p,s) → numeric(p,s)
+            - enum('a','b','c') → requires CREATE TYPE
+
+        Returns
+        -------
+        str
+            PostgreSQL SQL type.
+
+        Raises
+        ------
+        ValueError
+            If core_type is not recognized.
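+
+        Examples
+        --------
+        Illustrative conversions; outputs follow the core type mapping
+        implemented below:
+
+        >>> adapter.core_type_to_sql('int64')
+        'bigint'
+        >>> adapter.core_type_to_sql('decimal(10,2)')
+        'numeric(10,2)'
+        >>> adapter.core_type_to_sql('datetime(3)')
+        'timestamp(3)'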
+ """ + # Handle simple types without parameters + if core_type in CORE_TYPE_MAP: + return CORE_TYPE_MAP[core_type] + + # Handle parametrized types + if core_type.startswith("datetime"): + # datetime or datetime(precision) → timestamp or timestamp(precision) + if "(" in core_type: + # Extract precision: datetime(3) → timestamp(3) + precision = core_type[core_type.index("(") : core_type.index(")") + 1] + return f"timestamp{precision}" + return "timestamp" + + if core_type.startswith("char("): + # char(n) + return core_type + + if core_type.startswith("varchar("): + # varchar(n) + return core_type + + if core_type.startswith("decimal("): + # decimal(precision, scale) → numeric(precision, scale) + params = core_type[7:] # Remove "decimal" + return f"numeric{params}" + + if core_type.startswith("enum("): + # PostgreSQL requires CREATE TYPE for enums + # Extract enum values and generate a deterministic type name + enum_match = re.match(r"enum\s*\((.+)\)", core_type, re.I) + if enum_match: + # Parse enum values: enum('M','F') -> ['M', 'F'] + values_str = enum_match.group(1) + # Split by comma, handling quoted values + values = [v.strip().strip("'\"") for v in values_str.split(",")] + # Generate a deterministic type name based on values + # Use a hash to keep name reasonable length + import hashlib + + value_hash = hashlib.md5("_".join(sorted(values)).encode()).hexdigest()[:8] + type_name = f"enum_{value_hash}" + # Track this enum type for CREATE TYPE DDL + if not hasattr(self, "_pending_enum_types"): + self._pending_enum_types = {} + self._pending_enum_types[type_name] = values + # Return schema-qualified type reference using placeholder + # {database} will be replaced with actual schema name in table.py + return '"{database}".' + self.quote_identifier(type_name) + return "text" # Fallback if parsing fails + + raise ValueError(f"Unknown core type: {core_type}") + + def sql_type_to_core(self, sql_type: str) -> str | None: + """ + Convert PostgreSQL type to DataJoint core type (if mappable). + + Parameters + ---------- + sql_type : str + PostgreSQL SQL type. + + Returns + ------- + str or None + DataJoint core type if mappable, None otherwise. + """ + # Normalize type string (lowercase, strip spaces) + sql_type_lower = sql_type.lower().strip() + + # Direct mapping + if sql_type_lower in SQL_TO_CORE_MAP: + return SQL_TO_CORE_MAP[sql_type_lower] + + # Handle parametrized types + if sql_type_lower.startswith("timestamp"): + # timestamp(n) → datetime(n) + if "(" in sql_type_lower: + precision = sql_type_lower[sql_type_lower.index("(") : sql_type_lower.index(")") + 1] + return f"datetime{precision}" + return "datetime" + + if sql_type_lower.startswith("char("): + return sql_type # Keep size + + if sql_type_lower.startswith("varchar("): + return sql_type # Keep size + + if sql_type_lower.startswith("numeric("): + # numeric(p,s) → decimal(p,s) + params = sql_type_lower[7:] # Remove "numeric" + return f"decimal{params}" + + # Not a mappable core type + return None + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def create_schema_sql(self, schema_name: str) -> str: + """ + Generate CREATE SCHEMA statement for PostgreSQL. + + Parameters + ---------- + schema_name : str + Schema name. + + Returns + ------- + str + CREATE SCHEMA SQL. 
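+
+        Examples
+        --------
+        Illustrative call; ``my_schema`` is a placeholder name:
+
+        >>> adapter.create_schema_sql('my_schema')
+        'CREATE SCHEMA "my_schema"'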
+ """ + return f"CREATE SCHEMA {self.quote_identifier(schema_name)}" + + def drop_schema_sql(self, schema_name: str, if_exists: bool = True) -> str: + """ + Generate DROP SCHEMA statement for PostgreSQL. + + Parameters + ---------- + schema_name : str + Schema name. + if_exists : bool + Include IF EXISTS clause. + + Returns + ------- + str + DROP SCHEMA SQL. + """ + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP SCHEMA {if_exists_clause}{self.quote_identifier(schema_name)} CASCADE" + + def create_table_sql( + self, + table_name: str, + columns: list[dict[str, Any]], + primary_key: list[str], + foreign_keys: list[dict[str, Any]], + indexes: list[dict[str, Any]], + comment: str | None = None, + ) -> str: + """ + Generate CREATE TABLE statement for PostgreSQL. + + Parameters + ---------- + table_name : str + Fully qualified table name (schema.table). + columns : list[dict] + Column defs: [{name, type, nullable, default, comment}, ...] + primary_key : list[str] + Primary key column names. + foreign_keys : list[dict] + FK defs: [{columns, ref_table, ref_columns}, ...] + indexes : list[dict] + Index defs: [{columns, unique}, ...] + comment : str, optional + Table comment (added via separate COMMENT ON statement). + + Returns + ------- + str + CREATE TABLE SQL statement (comments via separate COMMENT ON). + """ + lines = [] + + # Column definitions + for col in columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + default = f" DEFAULT {col['default']}" if "default" in col else "" + # PostgreSQL comments are via COMMENT ON, not inline + lines.append(f"{col_name} {col_type} {nullable}{default}") + + # Primary key + if primary_key: + pk_cols = ", ".join(self.quote_identifier(col) for col in primary_key) + lines.append(f"PRIMARY KEY ({pk_cols})") + + # Foreign keys + for fk in foreign_keys: + fk_cols = ", ".join(self.quote_identifier(col) for col in fk["columns"]) + ref_cols = ", ".join(self.quote_identifier(col) for col in fk["ref_columns"]) + lines.append( + f"FOREIGN KEY ({fk_cols}) REFERENCES {fk['ref_table']} ({ref_cols}) " f"ON UPDATE CASCADE ON DELETE RESTRICT" + ) + + # Indexes - PostgreSQL creates indexes separately via CREATE INDEX + # (handled by caller after table creation) + + # Assemble CREATE TABLE (no ENGINE in PostgreSQL) + table_def = ",\n ".join(lines) + return f"CREATE TABLE IF NOT EXISTS {table_name} (\n {table_def}\n)" + + def drop_table_sql(self, table_name: str, if_exists: bool = True) -> str: + """Generate DROP TABLE statement for PostgreSQL.""" + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP TABLE {if_exists_clause}{table_name} CASCADE" + + def alter_table_sql( + self, + table_name: str, + add_columns: list[dict[str, Any]] | None = None, + drop_columns: list[str] | None = None, + modify_columns: list[dict[str, Any]] | None = None, + ) -> str: + """ + Generate ALTER TABLE statement for PostgreSQL. + + Parameters + ---------- + table_name : str + Table name. + add_columns : list[dict], optional + Columns to add. + drop_columns : list[str], optional + Column names to drop. + modify_columns : list[dict], optional + Columns to modify. + + Returns + ------- + str + ALTER TABLE SQL statement. 
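+
+        Examples
+        --------
+        Illustrative call; the table and column names are placeholders:
+
+        >>> adapter.alter_table_sql(
+        ...     '"lab"."subject"',
+        ...     add_columns=[{'name': 'age', 'type': 'integer'}],
+        ... )
+        'ALTER TABLE "lab"."subject" ADD COLUMN "age" integer NOT NULL'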
+ """ + clauses = [] + + if add_columns: + for col in add_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + clauses.append(f"ADD COLUMN {col_name} {col_type} {nullable}") + + if drop_columns: + for col_name in drop_columns: + clauses.append(f"DROP COLUMN {self.quote_identifier(col_name)}") + + if modify_columns: + # PostgreSQL requires ALTER COLUMN ... TYPE ... for type changes + for col in modify_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = col.get("nullable", False) + clauses.append(f"ALTER COLUMN {col_name} TYPE {col_type}") + if nullable: + clauses.append(f"ALTER COLUMN {col_name} DROP NOT NULL") + else: + clauses.append(f"ALTER COLUMN {col_name} SET NOT NULL") + + return f"ALTER TABLE {table_name} {', '.join(clauses)}" + + def add_comment_sql( + self, + object_type: str, + object_name: str, + comment: str, + ) -> str | None: + """ + Generate COMMENT ON statement for PostgreSQL. + + Parameters + ---------- + object_type : str + 'table' or 'column'. + object_name : str + Fully qualified object name. + comment : str + Comment text. + + Returns + ------- + str + COMMENT ON statement. + """ + comment_type = object_type.upper() + return f"COMMENT ON {comment_type} {object_name} IS {self.quote_string(comment)}" + + # ========================================================================= + # DML Generation + # ========================================================================= + + def insert_sql( + self, + table_name: str, + columns: list[str], + on_duplicate: str | None = None, + ) -> str: + """ + Generate INSERT statement for PostgreSQL. + + Parameters + ---------- + table_name : str + Table name. + columns : list[str] + Column names. + on_duplicate : str, optional + 'ignore' or 'update' (PostgreSQL uses ON CONFLICT). + + Returns + ------- + str + INSERT SQL with placeholders. + """ + cols = ", ".join(self.quote_identifier(col) for col in columns) + placeholders = ", ".join([self.parameter_placeholder] * len(columns)) + + base_insert = f"INSERT INTO {table_name} ({cols}) VALUES ({placeholders})" + + if on_duplicate == "ignore": + return f"{base_insert} ON CONFLICT DO NOTHING" + elif on_duplicate == "update": + # ON CONFLICT (pk_cols) DO UPDATE SET col=EXCLUDED.col + # Caller must provide constraint name or columns + updates = ", ".join(f"{self.quote_identifier(col)}=EXCLUDED.{self.quote_identifier(col)}" for col in columns) + return f"{base_insert} ON CONFLICT DO UPDATE SET {updates}" + else: + return base_insert + + def update_sql( + self, + table_name: str, + set_columns: list[str], + where_columns: list[str], + ) -> str: + """Generate UPDATE statement for PostgreSQL.""" + set_clause = ", ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in set_columns) + where_clause = " AND ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in where_columns) + return f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}" + + def delete_sql(self, table_name: str) -> str: + """Generate DELETE statement for PostgreSQL (WHERE added separately).""" + return f"DELETE FROM {table_name}" + + def upsert_on_duplicate_sql( + self, + table_name: str, + columns: list[str], + primary_key: list[str], + num_rows: int, + ) -> str: + """Generate INSERT ... ON CONFLICT ... 
DO UPDATE statement for PostgreSQL.""" + # Build column list + col_list = ", ".join(columns) + + # Build placeholders for VALUES + placeholders = ", ".join(["(%s)" % ", ".join(["%s"] * len(columns))] * num_rows) + + # Build conflict target (primary key columns) + conflict_cols = ", ".join(primary_key) + + # Build UPDATE clause (non-PK columns only) + non_pk_columns = [col for col in columns if col not in primary_key] + update_clauses = ", ".join(f"{col} = EXCLUDED.{col}" for col in non_pk_columns) + + return f""" + INSERT INTO {table_name} ({col_list}) + VALUES {placeholders} + ON CONFLICT ({conflict_cols}) DO UPDATE SET {update_clauses} + """ + + def skip_duplicates_clause( + self, + full_table_name: str, + primary_key: list[str], + ) -> str: + """ + Generate clause to skip duplicate key insertions for PostgreSQL. + + Uses ON CONFLICT (pk_cols) DO NOTHING to skip duplicates without + raising an error. + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). Unused but kept for + API compatibility with MySQL adapter. + primary_key : list[str] + Primary key column names (unquoted). + + Returns + ------- + str + PostgreSQL ON CONFLICT DO NOTHING clause. + """ + pk_cols = ", ".join(self.quote_identifier(pk) for pk in primary_key) + return f" ON CONFLICT ({pk_cols}) DO NOTHING" + + @property + def supports_inline_indexes(self) -> bool: + """ + PostgreSQL does not support inline INDEX in CREATE TABLE. + + Returns False to indicate indexes must be created separately + with CREATE INDEX statements. + """ + return False + + # ========================================================================= + # Introspection + # ========================================================================= + + def list_schemas_sql(self) -> str: + """Query to list all schemas in PostgreSQL.""" + return ( + "SELECT schema_name FROM information_schema.schemata " + "WHERE schema_name NOT IN ('pg_catalog', 'information_schema')" + ) + + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: + """Query to list tables in a schema.""" + sql = ( + f"SELECT table_name FROM information_schema.tables " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_type = 'BASE TABLE'" + ) + if pattern: + sql += f" AND table_name LIKE '{pattern}'" + return sql + + def get_table_info_sql(self, schema_name: str, table_name: str) -> str: + """Query to get table metadata including table comment.""" + schema_str = self.quote_string(schema_name) + table_str = self.quote_string(table_name) + regclass_expr = f"({schema_str} || '.' || {table_str})::regclass" + return ( + f"SELECT t.*, obj_description({regclass_expr}, 'pg_class') as table_comment " + f"FROM information_schema.tables t " + f"WHERE t.table_schema = {schema_str} " + f"AND t.table_name = {table_str}" + ) + + def get_columns_sql(self, schema_name: str, table_name: str) -> str: + """Query to get column definitions including comments.""" + # Use col_description() to retrieve column comments stored via COMMENT ON COLUMN + # The regclass cast allows using schema.table notation to get the OID + schema_str = self.quote_string(schema_name) + table_str = self.quote_string(table_name) + regclass_expr = f"({schema_str} || '.' 
|| {table_str})::regclass" + return ( + f"SELECT c.column_name, c.data_type, c.udt_name, c.is_nullable, c.column_default, " + f"c.character_maximum_length, c.numeric_precision, c.numeric_scale, " + f"col_description({regclass_expr}, c.ordinal_position) as column_comment " + f"FROM information_schema.columns c " + f"WHERE c.table_schema = {schema_str} " + f"AND c.table_name = {table_str} " + f"ORDER BY c.ordinal_position" + ) + + def get_primary_key_sql(self, schema_name: str, table_name: str) -> str: + """Query to get primary key columns.""" + return ( + f"SELECT column_name FROM information_schema.key_column_usage " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND constraint_name IN (" + f" SELECT constraint_name FROM information_schema.table_constraints " + f" WHERE table_schema = {self.quote_string(schema_name)} " + f" AND table_name = {self.quote_string(table_name)} " + f" AND constraint_type = 'PRIMARY KEY'" + f") " + f"ORDER BY ordinal_position" + ) + + def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: + """Query to get foreign key constraints.""" + return ( + f"SELECT kcu.constraint_name, kcu.column_name, " + f"ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name " + f"FROM information_schema.key_column_usage AS kcu " + f"JOIN information_schema.constraint_column_usage AS ccu " + f" ON kcu.constraint_name = ccu.constraint_name " + f"WHERE kcu.table_schema = {self.quote_string(schema_name)} " + f"AND kcu.table_name = {self.quote_string(table_name)} " + f"AND kcu.constraint_name IN (" + f" SELECT constraint_name FROM information_schema.table_constraints " + f" WHERE table_schema = {self.quote_string(schema_name)} " + f" AND table_name = {self.quote_string(table_name)} " + f" AND constraint_type = 'FOREIGN KEY'" + f") " + f"ORDER BY kcu.constraint_name, kcu.ordinal_position" + ) + + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: + """ + Query to get FK constraint details from information_schema. + + Returns matched pairs of (fk_column, parent_table, pk_column) for each + column in the foreign key constraint, ordered by position. + """ + return ( + "SELECT " + " kcu.column_name as fk_attrs, " + " '\"' || ccu.table_schema || '\".\"' || ccu.table_name || '\"' as parent, " + " ccu.column_name as pk_attrs " + "FROM information_schema.key_column_usage AS kcu " + "JOIN information_schema.referential_constraints AS rc " + " ON kcu.constraint_name = rc.constraint_name " + " AND kcu.constraint_schema = rc.constraint_schema " + "JOIN information_schema.key_column_usage AS ccu " + " ON rc.unique_constraint_name = ccu.constraint_name " + " AND rc.unique_constraint_schema = ccu.constraint_schema " + " AND kcu.ordinal_position = ccu.ordinal_position " + "WHERE kcu.constraint_name = %s " + " AND kcu.table_schema = %s " + " AND kcu.table_name = %s " + "ORDER BY kcu.ordinal_position" + ) + + def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[str] | None] | None: + """ + Parse PostgreSQL foreign key violation error message. 
+
+        PostgreSQL FK error format:
+        'update or delete on table "X" violates foreign key constraint "Y" on table "Z"'
+        Where:
+        - "X" is the referenced table (being deleted/updated)
+        - "Z" is the referencing table (has the FK, needs cascade delete)
+        """
+        import re
+
+        pattern = re.compile(
+            r'.*table "(?P<referenced_table>[^"]+)" violates foreign key constraint '
+            r'"(?P<name>[^"]+)" on table "(?P<referencing_table>[^"]+)"'
+        )
+
+        match = pattern.match(error_message)
+        if not match:
+            return None
+
+        result = match.groupdict()
+
+        # The child is the referencing table (the one with the FK that needs cascade delete)
+        # The parent is the referenced table (the one being deleted)
+        # The error doesn't include schema, so we return unqualified names
+        child = f'"{result["referencing_table"]}"'
+        parent = f'"{result["referenced_table"]}"'
+
+        return {
+            "child": child,
+            "name": f'"{result["name"]}"',
+            "fk_attrs": None,  # Not in error message, will need constraint query
+            "parent": parent,
+            "pk_attrs": None,  # Not in error message, will need constraint query
+        }
+
+    def get_indexes_sql(self, schema_name: str, table_name: str) -> str:
+        """Query to get index definitions."""
+        return (
+            f"SELECT indexname, indexdef FROM pg_indexes "
+            f"WHERE schemaname = {self.quote_string(schema_name)} "
+            f"AND tablename = {self.quote_string(table_name)}"
+        )
+
+    def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]:
+        """
+        Parse PostgreSQL column info into standardized format.
+
+        Parameters
+        ----------
+        row : dict
+            Row from information_schema.columns query with col_description() join.
+
+        Returns
+        -------
+        dict
+            Standardized column info with keys:
+            name, type, nullable, default, comment, key, extra
+        """
+        # For user-defined types (enums), use udt_name instead of data_type
+        # PostgreSQL reports enums as "USER-DEFINED" in data_type
+        data_type = row["data_type"]
+        if data_type == "USER-DEFINED":
+            data_type = row["udt_name"]
+
+        # Reconstruct parametrized types that PostgreSQL splits into separate fields
+        char_max_len = row.get("character_maximum_length")
+        num_precision = row.get("numeric_precision")
+        num_scale = row.get("numeric_scale")
+
+        if data_type == "character" and char_max_len is not None:
+            # char(n) - PostgreSQL reports as "character" with length in separate field
+            data_type = f"char({char_max_len})"
+        elif data_type == "character varying" and char_max_len is not None:
+            # varchar(n)
+            data_type = f"varchar({char_max_len})"
+        elif data_type == "numeric" and num_precision is not None:
+            # numeric(p,s) - reconstruct decimal type
+            if num_scale is not None and num_scale > 0:
+                data_type = f"decimal({num_precision},{num_scale})"
+            else:
+                data_type = f"decimal({num_precision})"
+
+        return {
+            "name": row["column_name"],
+            "type": data_type,
+            "nullable": row["is_nullable"] == "YES",
+            "default": row["column_default"],
+            "comment": row.get("column_comment"),  # Retrieved via col_description()
+            "key": "",  # PostgreSQL key info retrieved separately
+            "extra": "",  # PostgreSQL doesn't have auto_increment in same way
+        }
+
+    # =========================================================================
+    # Transactions
+    # =========================================================================
+
+    def start_transaction_sql(self, isolation_level: str | None = None) -> str:
+        """Generate BEGIN statement for PostgreSQL."""
+        if isolation_level:
+            return f"BEGIN ISOLATION LEVEL {isolation_level}"
+        return "BEGIN"
+
+    def commit_sql(self) -> str:
+        """Generate COMMIT statement."""
+        return "COMMIT"
+
+    def rollback_sql(self)
-> str: + """Generate ROLLBACK statement.""" + return "ROLLBACK" + + # ========================================================================= + # Functions and Expressions + # ========================================================================= + + def current_timestamp_expr(self, precision: int | None = None) -> str: + """ + CURRENT_TIMESTAMP expression for PostgreSQL. + + Parameters + ---------- + precision : int, optional + Fractional seconds precision (0-6). + + Returns + ------- + str + CURRENT_TIMESTAMP or CURRENT_TIMESTAMP(n). + """ + if precision is not None: + return f"CURRENT_TIMESTAMP({precision})" + return "CURRENT_TIMESTAMP" + + def interval_expr(self, value: int, unit: str) -> str: + """ + INTERVAL expression for PostgreSQL. + + Parameters + ---------- + value : int + Interval value. + unit : str + Time unit (singular: 'second', 'minute', 'hour', 'day'). + + Returns + ------- + str + INTERVAL 'n units' (e.g., "INTERVAL '5 seconds'"). + """ + # PostgreSQL uses plural unit names and quotes + unit_plural = unit.lower() + "s" if not unit.endswith("s") else unit.lower() + return f"INTERVAL '{value} {unit_plural}'" + + def current_user_expr(self) -> str: + """PostgreSQL current user expression.""" + return "current_user" + + def json_path_expr(self, column: str, path: str, return_type: str | None = None) -> str: + """ + Generate PostgreSQL jsonb_extract_path_text() expression. + + Parameters + ---------- + column : str + Column name containing JSON data. + path : str + JSON path (e.g., 'field' or 'nested.field'). + return_type : str, optional + Return type specification for casting (e.g., 'float', 'decimal(10,2)'). + + Returns + ------- + str + PostgreSQL jsonb_extract_path_text() expression, with optional cast. + + Examples + -------- + >>> adapter.json_path_expr('data', 'field') + 'jsonb_extract_path_text("data", \\'field\\')' + >>> adapter.json_path_expr('data', 'nested.field') + 'jsonb_extract_path_text("data", \\'nested\\', \\'field\\')' + >>> adapter.json_path_expr('data', 'value', 'float') + 'jsonb_extract_path_text("data", \\'value\\')::float' + """ + quoted_col = self.quote_identifier(column) + # Split path by '.' for nested access, handling array notation + path_parts = [] + for part in path.split("."): + # Handle array access like field[0] + if "[" in part: + base, rest = part.split("[", 1) + path_parts.append(base) + # Extract array indices + indices = rest.rstrip("]").split("][") + path_parts.extend(indices) + else: + path_parts.append(part) + path_args = ", ".join(f"'{part}'" for part in path_parts) + expr = f"jsonb_extract_path_text({quoted_col}, {path_args})" + # Add cast if return type specified + if return_type: + # Map DataJoint types to PostgreSQL types + pg_type = return_type.lower() + if pg_type in ("unsigned", "signed"): + pg_type = "integer" + elif pg_type == "double": + pg_type = "double precision" + expr = f"({expr})::{pg_type}" + return expr + + def translate_expression(self, expr: str) -> str: + """ + Translate SQL expression for PostgreSQL compatibility. + + Converts MySQL-specific functions to PostgreSQL equivalents: + - GROUP_CONCAT(col) → STRING_AGG(col::text, ',') + - GROUP_CONCAT(col SEPARATOR 'sep') → STRING_AGG(col::text, 'sep') + + Parameters + ---------- + expr : str + SQL expression that may contain function calls. + + Returns + ------- + str + Translated expression for PostgreSQL. 
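+
+        Examples
+        --------
+        Illustrative translations produced by this implementation (column
+        names are placeholders):
+
+        >>> adapter.translate_expression('GROUP_CONCAT(name)')
+        "STRING_AGG(name::text, ',')"
+        >>> adapter.translate_expression('YEAR(dob)')
+        'EXTRACT(YEAR FROM dob)::int'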
+ """ + import re + + # GROUP_CONCAT(col) → STRING_AGG(col::text, ',') + # GROUP_CONCAT(col SEPARATOR 'sep') → STRING_AGG(col::text, 'sep') + def replace_group_concat(match): + inner = match.group(1).strip() + # Check for SEPARATOR clause + sep_match = re.match(r"(.+?)\s+SEPARATOR\s+(['\"])(.+?)\2", inner, re.IGNORECASE) + if sep_match: + col = sep_match.group(1).strip() + sep = sep_match.group(3) + return f"STRING_AGG({col}::text, '{sep}')" + else: + return f"STRING_AGG({inner}::text, ',')" + + expr = re.sub(r"GROUP_CONCAT\s*\((.+?)\)", replace_group_concat, expr, flags=re.IGNORECASE) + + # Replace simple functions FIRST before complex patterns + # CURDATE() → CURRENT_DATE + expr = re.sub(r"CURDATE\s*\(\s*\)", "CURRENT_DATE", expr, flags=re.IGNORECASE) + + # NOW() → CURRENT_TIMESTAMP + expr = re.sub(r"\bNOW\s*\(\s*\)", "CURRENT_TIMESTAMP", expr, flags=re.IGNORECASE) + + # YEAR(date) → EXTRACT(YEAR FROM date)::int + expr = re.sub(r"\bYEAR\s*\(\s*([^)]+)\s*\)", r"EXTRACT(YEAR FROM \1)::int", expr, flags=re.IGNORECASE) + + # MONTH(date) → EXTRACT(MONTH FROM date)::int + expr = re.sub(r"\bMONTH\s*\(\s*([^)]+)\s*\)", r"EXTRACT(MONTH FROM \1)::int", expr, flags=re.IGNORECASE) + + # DAY(date) → EXTRACT(DAY FROM date)::int + expr = re.sub(r"\bDAY\s*\(\s*([^)]+)\s*\)", r"EXTRACT(DAY FROM \1)::int", expr, flags=re.IGNORECASE) + + # TIMESTAMPDIFF(YEAR, d1, d2) → EXTRACT(YEAR FROM AGE(d2, d1))::int + # Use a more robust regex that handles the comma-separated arguments + def replace_timestampdiff(match): + unit = match.group(1).upper() + date1 = match.group(2).strip() + date2 = match.group(3).strip() + if unit == "YEAR": + return f"EXTRACT(YEAR FROM AGE({date2}, {date1}))::int" + elif unit == "MONTH": + return f"(EXTRACT(YEAR FROM AGE({date2}, {date1})) * 12 + EXTRACT(MONTH FROM AGE({date2}, {date1})))::int" + elif unit == "DAY": + return f"({date2}::date - {date1}::date)" + else: + return f"EXTRACT({unit} FROM AGE({date2}, {date1}))::int" + + # Match TIMESTAMPDIFF with proper argument parsing + # The arguments are: unit, date1, date2 - we need to handle identifiers and CURRENT_DATE + expr = re.sub( + r"TIMESTAMPDIFF\s*\(\s*(\w+)\s*,\s*([^,]+)\s*,\s*([^)]+)\s*\)", + replace_timestampdiff, + expr, + flags=re.IGNORECASE, + ) + + # SUM(expr='value') → SUM((expr='value')::int) for PostgreSQL boolean handling + # This handles patterns like SUM(sex='F') which produce boolean in PostgreSQL + def replace_sum_comparison(match): + inner = match.group(1).strip() + # Check if inner contains a comparison operator + if re.search(r"[=<>!]", inner) and not inner.startswith("("): + return f"SUM(({inner})::int)" + return match.group(0) # Return unchanged if no comparison + + expr = re.sub(r"\bSUM\s*\(\s*([^)]+)\s*\)", replace_sum_comparison, expr, flags=re.IGNORECASE) + + return expr + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def format_column_definition( + self, + name: str, + sql_type: str, + nullable: bool = False, + default: str | None = None, + comment: str | None = None, + ) -> str: + """ + Format a column definition for PostgreSQL DDL. 
+ + Examples + -------- + >>> adapter.format_column_definition('user_id', 'bigint', nullable=False, comment='user ID') + '"user_id" bigint NOT NULL' + """ + parts = [self.quote_identifier(name), sql_type] + if default: + parts.append(default) + elif not nullable: + parts.append("NOT NULL") + # Note: PostgreSQL comments handled separately via COMMENT ON + return " ".join(parts) + + def table_options_clause(self, comment: str | None = None) -> str: + """ + Generate PostgreSQL table options clause (empty - no ENGINE in PostgreSQL). + + Examples + -------- + >>> adapter.table_options_clause('test table') + '' + >>> adapter.table_options_clause() + '' + """ + return "" # PostgreSQL uses COMMENT ON TABLE separately + + def table_comment_ddl(self, full_table_name: str, comment: str) -> str | None: + """ + Generate COMMENT ON TABLE statement for PostgreSQL. + + Examples + -------- + >>> adapter.table_comment_ddl('"schema"."table"', 'test comment') + 'COMMENT ON TABLE "schema"."table" IS \\'test comment\\'' + """ + # Escape single quotes by doubling them + escaped_comment = comment.replace("'", "''") + return f"COMMENT ON TABLE {full_table_name} IS '{escaped_comment}'" + + def column_comment_ddl(self, full_table_name: str, column_name: str, comment: str) -> str | None: + """ + Generate COMMENT ON COLUMN statement for PostgreSQL. + + Examples + -------- + >>> adapter.column_comment_ddl('"schema"."table"', 'column', 'test comment') + 'COMMENT ON COLUMN "schema"."table"."column" IS \\'test comment\\'' + """ + quoted_col = self.quote_identifier(column_name) + # Escape single quotes by doubling them (PostgreSQL string literal syntax) + escaped_comment = comment.replace("'", "''") + return f"COMMENT ON COLUMN {full_table_name}.{quoted_col} IS '{escaped_comment}'" + + def enum_type_ddl(self, type_name: str, values: list[str]) -> str | None: + """ + Generate CREATE TYPE statement for PostgreSQL enum. + + Examples + -------- + >>> adapter.enum_type_ddl('status_type', ['active', 'inactive']) + 'CREATE TYPE "status_type" AS ENUM (\\'active\\', \\'inactive\\')' + """ + quoted_values = ", ".join(f"'{v}'" for v in values) + return f"CREATE TYPE {self.quote_identifier(type_name)} AS ENUM ({quoted_values})" + + def get_pending_enum_ddl(self, schema_name: str) -> list[str]: + """ + Get DDL statements for pending enum types and clear the pending list. + + PostgreSQL requires CREATE TYPE statements before using enum types in + column definitions. This method returns DDL for enum types accumulated + during type conversion and clears the pending list. + + Parameters + ---------- + schema_name : str + Schema name to qualify enum type names. + + Returns + ------- + list[str] + List of CREATE TYPE statements (if any pending). + """ + ddl_statements = [] + if hasattr(self, "_pending_enum_types") and self._pending_enum_types: + for type_name, values in self._pending_enum_types.items(): + # Generate CREATE TYPE with schema qualification + quoted_type = f"{self.quote_identifier(schema_name)}.{self.quote_identifier(type_name)}" + quoted_values = ", ".join(f"'{v}'" for v in values) + ddl_statements.append(f"CREATE TYPE {quoted_type} AS ENUM ({quoted_values})") + self._pending_enum_types = {} + return ddl_statements + + def job_metadata_columns(self) -> list[str]: + """ + Return PostgreSQL-specific job metadata column definitions. 
+ + Examples + -------- + >>> adapter.job_metadata_columns() + ['"_job_start_time" timestamp DEFAULT NULL', + '"_job_duration" real DEFAULT NULL', + '"_job_version" varchar(64) DEFAULT \\'\\''] + """ + return [ + '"_job_start_time" timestamp DEFAULT NULL', + '"_job_duration" real DEFAULT NULL', + "\"_job_version\" varchar(64) DEFAULT ''", + ] + + # ========================================================================= + # Error Translation + # ========================================================================= + + def translate_error(self, error: Exception, query: str = "") -> Exception: + """ + Translate PostgreSQL error to DataJoint exception. + + Parameters + ---------- + error : Exception + PostgreSQL exception (typically psycopg2 error). + query : str, optional + SQL query that caused the error (for context). + + Returns + ------- + Exception + DataJoint exception or original error. + """ + if not hasattr(error, "pgcode"): + return error + + pgcode = error.pgcode + + # PostgreSQL error code mapping + # Reference: https://www.postgresql.org/docs/current/errcodes-appendix.html + match pgcode: + # Integrity constraint violations + case "23505": # unique_violation + return errors.DuplicateError(str(error)) + case "23503": # foreign_key_violation + return errors.IntegrityError(str(error)) + case "23502": # not_null_violation + return errors.MissingAttributeError(str(error)) + + # Syntax errors + case "42601": # syntax_error + return errors.QuerySyntaxError(str(error), "") + + # Undefined errors + case "42P01": # undefined_table + return errors.MissingTableError(str(error), "") + case "42703": # undefined_column + return errors.UnknownAttributeError(str(error)) + + # Connection errors + case "08006" | "08003" | "08000": # connection_failure + return errors.LostConnectionError(str(error)) + case "57P01": # admin_shutdown + return errors.LostConnectionError(str(error)) + + # Access errors + case "42501": # insufficient_privilege + return errors.AccessError("Insufficient privileges.", str(error), "") + + # All other errors pass through unchanged + case _: + return error + + # ========================================================================= + # Native Type Validation + # ========================================================================= + + def validate_native_type(self, type_str: str) -> bool: + """ + Check if a native PostgreSQL type string is valid. + + Parameters + ---------- + type_str : str + Type string to validate. + + Returns + ------- + bool + True if valid PostgreSQL type. 
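+
+        Examples
+        --------
+        Illustrative checks (assuming ``adapter`` is a ``PostgreSQLAdapter``):
+
+        >>> adapter.validate_native_type('double precision')
+        True
+        >>> adapter.validate_native_type('tinyint')
+        False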
+ """ + type_lower = type_str.lower().strip() + + # PostgreSQL native types (simplified validation) + valid_types = { + # Integer types + "smallint", + "integer", + "int", + "bigint", + "smallserial", + "serial", + "bigserial", + # Floating point + "real", + "double precision", + "numeric", + "decimal", + # String types + "char", + "varchar", + "text", + # Binary + "bytea", + # Boolean + "boolean", + "bool", + # Temporal types + "date", + "time", + "timetz", + "timestamp", + "timestamptz", + "interval", + # UUID + "uuid", + # JSON + "json", + "jsonb", + # Network types + "inet", + "cidr", + "macaddr", + # Geometric types + "point", + "line", + "lseg", + "box", + "path", + "polygon", + "circle", + # Other + "money", + "xml", + } + + # Extract base type (before parentheses or brackets) + base_type = type_lower.split("(")[0].split("[")[0].strip() + + return base_type in valid_types + + # ========================================================================= + # PostgreSQL-Specific Enum Handling + # ========================================================================= + + def create_enum_type_sql( + self, + schema: str, + table: str, + column: str, + values: list[str], + ) -> str: + """ + Generate CREATE TYPE statement for PostgreSQL enum. + + Parameters + ---------- + schema : str + Schema name. + table : str + Table name. + column : str + Column name. + values : list[str] + Enum values. + + Returns + ------- + str + CREATE TYPE ... AS ENUM statement. + """ + type_name = f"{schema}_{table}_{column}_enum" + quoted_values = ", ".join(self.quote_string(v) for v in values) + return f"CREATE TYPE {self.quote_identifier(type_name)} AS ENUM ({quoted_values})" + + def drop_enum_type_sql(self, schema: str, table: str, column: str) -> str: + """ + Generate DROP TYPE statement for PostgreSQL enum. + + Parameters + ---------- + schema : str + Schema name. + table : str + Table name. + column : str + Column name. + + Returns + ------- + str + DROP TYPE statement. + """ + type_name = f"{schema}_{table}_{column}_enum" + return f"DROP TYPE IF EXISTS {self.quote_identifier(type_name)} CASCADE" + + def get_table_enum_types_sql(self, schema_name: str, table_name: str) -> str: + """ + Query to get enum types used by a table's columns. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query that returns enum type names (schema-qualified). + """ + return f""" + SELECT DISTINCT + n.nspname || '.' || t.typname as enum_type + FROM pg_catalog.pg_type t + JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + JOIN pg_catalog.pg_attribute a ON a.atttypid = t.oid + JOIN pg_catalog.pg_class c ON c.oid = a.attrelid + JOIN pg_catalog.pg_namespace cn ON cn.oid = c.relnamespace + WHERE t.typtype = 'e' + AND cn.nspname = {self.quote_string(schema_name)} + AND c.relname = {self.quote_string(table_name)} + """ + + def drop_enum_types_for_table(self, schema_name: str, table_name: str) -> list[str]: + """ + Generate DROP TYPE statements for all enum types used by a table. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + list[str] + List of DROP TYPE IF EXISTS statements. 
+ """ + # Returns list of DDL statements - caller should execute query first + # to get actual enum types, then call this with results + return [] # Placeholder - actual implementation requires query execution + + def drop_enum_type_ddl(self, enum_type_name: str) -> str: + """ + Generate DROP TYPE IF EXISTS statement for a PostgreSQL enum. + + Parameters + ---------- + enum_type_name : str + Fully qualified enum type name (schema.typename). + + Returns + ------- + str + DROP TYPE IF EXISTS statement with CASCADE. + """ + # Split schema.typename and quote each part + parts = enum_type_name.split(".") + if len(parts) == 2: + qualified_name = f"{self.quote_identifier(parts[0])}.{self.quote_identifier(parts[1])}" + else: + qualified_name = self.quote_identifier(enum_type_name) + return f"DROP TYPE IF EXISTS {qualified_name} CASCADE" diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index b40ebbda4..244a2dd53 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -432,7 +432,9 @@ def _populate_direct( else: # spawn multiple processes self.connection.close() - del self.connection._conn.ctx # SSLContext is not pickleable + # Remove SSLContext if present (MySQL-specific, not pickleable) + if hasattr(self.connection._conn, "ctx"): + del self.connection._conn.ctx with ( mp.Pool(processes, _initialize_populate, (self, None, populate_kwargs)) as pool, tqdm(desc="Processes: ", total=nkeys) if display_progress else contextlib.nullcontext() as progress_bar, @@ -522,7 +524,9 @@ def handler(signum, frame): else: # spawn multiple processes self.connection.close() - del self.connection._conn.ctx # SSLContext is not pickleable + # Remove SSLContext if present (MySQL-specific, not pickleable) + if hasattr(self.connection._conn, "ctx"): + del self.connection._conn.ctx with ( mp.Pool(processes, _initialize_populate, (self, self.jobs, populate_kwargs)) as pool, tqdm(desc="Processes: ", total=nkeys) @@ -699,17 +703,26 @@ def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int] todo_sql = todo.make_sql() target_sql = self.make_sql() + # Get adapter for backend-specific quoting + adapter = self.connection.adapter + q = adapter.quote_identifier + + # Alias names for subqueries + ks_alias = q("$ks") + tgt_alias = q("$tgt") + # Build join condition on common attributes - join_cond = " AND ".join(f"`$ks`.`{attr}` = `$tgt`.`{attr}`" for attr in common_attrs) + join_cond = " AND ".join(f"{ks_alias}.{q(attr)} = {tgt_alias}.{q(attr)}" for attr in common_attrs) # Build DISTINCT key expression for counting unique jobs - # Use CONCAT for composite keys to create a single distinct value + # Use CONCAT_WS for composite keys (supported by both MySQL and PostgreSQL) if len(pk_attrs) == 1: - distinct_key = f"`$ks`.`{pk_attrs[0]}`" - null_check = f"`$tgt`.`{common_attrs[0]}`" + distinct_key = f"{ks_alias}.{q(pk_attrs[0])}" + null_check = f"{tgt_alias}.{q(common_attrs[0])}" else: - distinct_key = "CONCAT_WS('|', {})".format(", ".join(f"`$ks`.`{attr}`" for attr in pk_attrs)) - null_check = f"`$tgt`.`{common_attrs[0]}`" + key_cols = ", ".join(f"{ks_alias}.{q(attr)}" for attr in pk_attrs) + distinct_key = f"CONCAT_WS('|', {key_cols})" + null_check = f"{tgt_alias}.{q(common_attrs[0])}" # Single aggregation query: # - COUNT(DISTINCT key) gives total unique jobs in key_source @@ -718,8 +731,8 @@ def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int] SELECT COUNT(DISTINCT {distinct_key}) AS total, COUNT(DISTINCT CASE WHEN {null_check} IS 
NULL THEN {distinct_key} END) AS remaining - FROM ({todo_sql}) AS `$ks` - LEFT JOIN ({target_sql}) AS `$tgt` ON {join_cond} + FROM ({todo_sql}) AS {ks_alias} + LEFT JOIN ({target_sql}) AS {tgt_alias} ON {join_cond} """ result = self.connection.query(sql).fetchone() diff --git a/src/datajoint/blob.py b/src/datajoint/blob.py index d94417d6d..633f55b79 100644 --- a/src/datajoint/blob.py +++ b/src/datajoint/blob.py @@ -149,6 +149,9 @@ def squeeze(self, array: np.ndarray, convert_to_scalar: bool = True) -> np.ndarr return array.item() if array.ndim == 0 and convert_to_scalar else array def unpack(self, blob): + # PostgreSQL returns bytea as memoryview; convert to bytes for string operations + if isinstance(blob, memoryview): + blob = bytes(blob) self._blob = blob try: # decompress diff --git a/src/datajoint/codecs.py b/src/datajoint/codecs.py index afa60321f..5c192d46e 100644 --- a/src/datajoint/codecs.py +++ b/src/datajoint/codecs.py @@ -544,7 +544,9 @@ def decode_attribute(attr, data, squeeze: bool = False): # Process the final storage type (what's in the database) if final_dtype.lower() == "json": - data = json.loads(data) + # psycopg2 auto-deserializes JSON to dict/list; only parse strings + if isinstance(data, str): + data = json.loads(data) elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob"): pass # Blob data is already bytes elif final_dtype.lower() == "binary(16)": @@ -562,7 +564,10 @@ def decode_attribute(attr, data, squeeze: bool = False): # No codec - handle native types if attr.json: - return json.loads(data) + # psycopg2 auto-deserializes JSON to dict/list; only parse strings + if isinstance(data, str): + return json.loads(data) + return data if attr.uuid: import uuid as uuid_module diff --git a/src/datajoint/condition.py b/src/datajoint/condition.py index 9c6f933d1..0335d6adb 100644 --- a/src/datajoint/condition.py +++ b/src/datajoint/condition.py @@ -31,7 +31,7 @@ JSON_PATTERN = re.compile(r"^(?P\w+)(\.(?P[\w.*\[\]]+))?(:(?P[\w(,\s)]+))?$") -def translate_attribute(key: str) -> tuple[dict | None, str]: +def translate_attribute(key: str, adapter=None) -> tuple[dict | None, str]: """ Translate an attribute key, handling JSON path notation. @@ -39,6 +39,9 @@ def translate_attribute(key: str) -> tuple[dict | None, str]: ---------- key : str Attribute name, optionally with JSON path (e.g., ``"attr.path.field"``). + adapter : DatabaseAdapter, optional + Database adapter for backend-specific SQL generation. + If not provided, uses MySQL syntax for backward compatibility. 
Returns ------- @@ -53,9 +56,14 @@ def translate_attribute(key: str) -> tuple[dict | None, str]: if match["path"] is None: return match, match["attr"] else: - return match, "json_value(`{}`, _utf8mb4'$.{}'{})".format( - *[((f" returning {v}" if k == "type" else v) if v else "") for k, v in match.items()] - ) + # Use adapter's json_path_expr if available, otherwise fall back to MySQL syntax + if adapter is not None: + return match, adapter.json_path_expr(match["attr"], match["path"], match["type"]) + else: + # Legacy MySQL syntax for backward compatibility + return match, "json_value(`{}`, _utf8mb4'$.{}'{})".format( + *[((f" returning {v}" if k == "type" else v) if v else "") for k, v in match.items()] + ) class PromiscuousOperand: @@ -301,16 +309,21 @@ def make_condition( """ from .expression import Aggregation, QueryExpression, U + # Get adapter for backend-agnostic SQL generation + adapter = query_expression.connection.adapter + def prep_value(k, v): """prepare SQL condition""" - key_match, k = translate_attribute(k) - if key_match["path"] is None: - k = f"`{k}`" - if query_expression.heading[key_match["attr"]].json and key_match["path"] is not None and isinstance(v, dict): + key_match, k = translate_attribute(k, adapter) + is_json_path = key_match is not None and key_match.get("path") is not None + + if not is_json_path: + k = adapter.quote_identifier(k) + if is_json_path and isinstance(v, dict): return f"{k}='{json.dumps(v)}'" if v is None: return f"{k} IS NULL" - if query_expression.heading[key_match["attr"]].uuid: + if key_match is not None and query_expression.heading[key_match["attr"]].uuid: if not isinstance(v, uuid.UUID): try: v = uuid.UUID(v) @@ -327,10 +340,12 @@ def prep_value(k, v): list, ), ): - return f'{k}="{v}"' + # Use single quotes for string literals (works for both MySQL and PostgreSQL) + return f"{k}='{v}'" if isinstance(v, str): - v = v.replace("%", "%%").replace("\\", "\\\\") - return f'{k}="{v}"' + # Escape single quotes by doubling them, and escape % for driver + v = v.replace("'", "''").replace("%", "%%").replace("\\", "\\\\") + return f"{k}='{v}'" return f"{k}={v}" def combine_conditions(negate, conditions): @@ -410,10 +425,12 @@ def combine_conditions(negate, conditions): # without common attributes, any non-empty set matches everything (not negate if condition else negate) if not common_attributes - else "({fields}) {not_}in ({subquery})".format( - fields="`" + "`,`".join(common_attributes) + "`", - not_="not " if negate else "", - subquery=condition.make_sql(common_attributes), + else ( + "({fields}) {not_}in ({subquery})".format( + fields=", ".join(adapter.quote_identifier(a) for a in common_attributes), + not_="not " if negate else "", + subquery=condition.make_sql(common_attributes), + ) ) ) diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index 43dd43fa8..92680e0d2 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -14,9 +14,8 @@ from getpass import getpass from typing import Callable -import pymysql as client - from . import errors +from .adapters import get_adapter from .blob import pack, unpack from .dependencies import Dependencies from .settings import config @@ -29,7 +28,7 @@ cache_key = "query_cache" # the key to lookup the query_cache folder in dj.config -def translate_query_error(client_error: Exception, query: str) -> Exception: +def translate_query_error(client_error: Exception, query: str, adapter) -> Exception: """ Translate client error to the corresponding DataJoint exception. 
@@ -39,6 +38,8 @@ def translate_query_error(client_error: Exception, query: str) -> Exception: The exception raised by the client interface. query : str SQL query with placeholders. + adapter : DatabaseAdapter + The database adapter instance. Returns ------- @@ -47,47 +48,7 @@ def translate_query_error(client_error: Exception, query: str) -> Exception: or the original error if no mapping exists. """ logger.debug("type: {}, args: {}".format(type(client_error), client_error.args)) - - err, *args = client_error.args - - match err: - # Loss of connection errors - case 0 | "(0, '')": - return errors.LostConnectionError("Server connection lost due to an interface error.", *args) - case 2006: - return errors.LostConnectionError("Connection timed out", *args) - case 2013: - return errors.LostConnectionError("Server connection lost", *args) - - # Access errors - case 1044 | 1142: - return errors.AccessError("Insufficient privileges.", args[0], query) - - # Integrity errors - case 1062: - return errors.DuplicateError(*args) - case 1217 | 1451 | 1452 | 3730: - # 1217: Cannot delete parent row (FK constraint) - # 1451: Cannot delete/update parent row (FK constraint) - # 1452: Cannot add/update child row (FK constraint) - # 3730: Cannot drop table referenced by FK constraint - return errors.IntegrityError(*args) - - # Syntax errors - case 1064: - return errors.QuerySyntaxError(args[0], query) - - # Existence errors - case 1146: - return errors.MissingTableError(args[0], query) - case 1364: - return errors.MissingAttributeError(*args) - case 1054: - return errors.UnknownAttributeError(*args) - - # All other errors pass through unchanged - case _: - return client_error + return adapter.translate_error(client_error, query) def conn( @@ -211,15 +172,29 @@ def __init__( port = config["database.port"] self.conn_info = dict(host=host, port=port, user=user, passwd=password) if use_tls is not False: - self.conn_info["ssl"] = use_tls if isinstance(use_tls, dict) else {"ssl": {}} + # use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config) + if isinstance(use_tls, dict): + self.conn_info["ssl"] = use_tls + elif use_tls is None: + # Auto-detect: try SSL, fallback to non-SSL if server doesn't support it + self.conn_info["ssl"] = True + else: + # use_tls=True: enable SSL with default settings + self.conn_info["ssl"] = True self.conn_info["ssl_input"] = use_tls self.init_fun = init_fun self._conn = None self._query_cache = None + self._is_closed = True # Mark as closed until connect() succeeds + + # Select adapter based on configured backend + backend = config["database.backend"] + self.adapter = get_adapter(backend) + self.connect() if self.is_connected: logger.info("DataJoint {version} connected to {user}@{host}:{port}".format(version=__version__, **self.conn_info)) - self.connection_id = self.query("SELECT connection_id()").fetchone()[0] + self.connection_id = self.adapter.get_connection_id(self._conn) else: raise errors.LostConnectionError("Connection failed {user}@{host}:{port}".format(**self.conn_info)) self._in_transaction = False @@ -238,26 +213,36 @@ def connect(self) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*deprecated.*") try: - self._conn = client.connect( + # Use adapter to create connection + self._conn = self.adapter.connect( + host=self.conn_info["host"], + port=self.conn_info["port"], + user=self.conn_info["user"], + password=self.conn_info["passwd"], init_command=self.init_fun, - 
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", charset=config["connection.charset"], - **{k: v for k, v in self.conn_info.items() if k not in ["ssl_input"]}, + use_tls=self.conn_info.get("ssl"), ) - except client.err.InternalError: - self._conn = client.connect( - init_command=self.init_fun, - sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", - charset=config["connection.charset"], - **{ - k: v - for k, v in self.conn_info.items() - if not (k == "ssl_input" or k == "ssl" and self.conn_info["ssl_input"] is None) - }, - ) - self._conn.autocommit(True) + except Exception as ssl_error: + # If SSL fails, retry without SSL (if it was auto-detected) + if self.conn_info.get("ssl_input") is None: + logger.warning( + "SSL connection failed (%s). Falling back to non-SSL connection. " + "To require SSL, set use_tls=True explicitly.", + ssl_error, + ) + self._conn = self.adapter.connect( + host=self.conn_info["host"], + port=self.conn_info["port"], + user=self.conn_info["user"], + password=self.conn_info["passwd"], + init_command=self.init_fun, + charset=config["connection.charset"], + use_tls=False, # Explicitly disable SSL for fallback + ) + else: + raise + self._is_closed = False # Mark as connected after successful connection def set_query_cache(self, query_cache: str | None = None) -> None: """ @@ -285,7 +270,9 @@ def purge_query_cache(self) -> None: def close(self) -> None: """Close the database connection.""" - self._conn.close() + if self._conn is not None: + self._conn.close() + self._is_closed = True def __enter__(self) -> "Connection": """ @@ -347,7 +334,7 @@ def ping(self) -> None: Exception If the connection is closed. """ - self._conn.ping(reconnect=False) + self.adapter.ping(self._conn) @property def is_connected(self) -> bool: @@ -359,22 +346,24 @@ def is_connected(self) -> bool: bool True if connected. 
""" + if self._is_closed: + return False try: self.ping() except: + self._is_closed = True return False return True - @staticmethod - def _execute_query(cursor, query, args, suppress_warnings): + def _execute_query(self, cursor, query, args, suppress_warnings): try: with warnings.catch_warnings(): if suppress_warnings: # suppress all warnings arising from underlying SQL library warnings.simplefilter("ignore") cursor.execute(query, args) - except client.err.Error as err: - raise translate_query_error(err, query) + except Exception as err: + raise translate_query_error(err, query, self.adapter) def query( self, @@ -418,7 +407,8 @@ def query( if use_query_cache: if not config[cache_key]: raise errors.DataJointError(f"Provide filepath dj.config['{cache_key}'] when using query caching.") - hash_ = hashlib.md5((str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode() + pack(args)).hexdigest() + # Cache key is backend-specific (no identifier normalization needed) + hash_ = hashlib.md5((str(self._query_cache)).encode() + pack(args) + query.encode()).hexdigest() cache_path = pathlib.Path(config[cache_key]) / str(hash_) try: buffer = cache_path.read_bytes() @@ -430,20 +420,19 @@ def query( if reconnect is None: reconnect = config["database.reconnect"] logger.debug("Executing SQL:" + query[:query_log_max_length]) - cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor - cursor = self._conn.cursor(cursor=cursor_class) + cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict) try: self._execute_query(cursor, query, args, suppress_warnings) except errors.LostConnectionError: if not reconnect: raise - logger.warning("Reconnecting to MySQL server.") + logger.warning("Reconnecting to database server.") self.connect() if self._in_transaction: self.cancel_transaction() raise errors.LostConnectionError("Connection was lost during a transaction.") logger.debug("Re-executing") - cursor = self._conn.cursor(cursor=cursor_class) + cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict) self._execute_query(cursor, query, args, suppress_warnings) if use_query_cache: @@ -462,7 +451,7 @@ def get_user(self) -> str: str User name and host as ``'user@host'``. """ - return self.query("SELECT user()").fetchone()[0] + return self.query(f"SELECT {self.adapter.current_user_expr()}").fetchone()[0] # ---------- transaction processing @property @@ -489,19 +478,19 @@ def start_transaction(self) -> None: """ if self.in_transaction: raise errors.DataJointError("Nested connections are not supported.") - self.query("START TRANSACTION WITH CONSISTENT SNAPSHOT") + self.query(self.adapter.start_transaction_sql()) self._in_transaction = True logger.debug("Transaction started") def cancel_transaction(self) -> None: """Cancel the current transaction and roll back all changes.""" - self.query("ROLLBACK") + self.query(self.adapter.rollback_sql()) self._in_transaction = False logger.debug("Transaction cancelled. 
Rolling back ...") def commit_transaction(self) -> None: """Commit all changes and close the transaction.""" - self.query("COMMIT") + self.query(self.adapter.commit_sql()) self._in_transaction = False logger.debug("Transaction committed and closed.") diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py index a640c444f..375daa07e 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -31,8 +31,8 @@ "bool": (r"bool$", "tinyint"), # UUID (stored as binary) "uuid": (r"uuid$", "binary(16)"), - # JSON - "json": (r"json$", None), # json passes through as-is + # JSON (matches both json and jsonb for PostgreSQL compatibility) + "json": (r"jsonb?$", None), # json/jsonb passes through as-is # Binary (bytes maps to longblob in MySQL, bytea in PostgreSQL) "bytes": (r"bytes$", "longblob"), # Temporal @@ -190,6 +190,7 @@ def compile_foreign_key( attr_sql: list[str], foreign_key_sql: list[str], index_sql: list[str], + adapter, fk_attribute_map: dict[str, tuple[str, str]] | None = None, ) -> None: """ @@ -212,6 +213,8 @@ def compile_foreign_key( SQL FOREIGN KEY constraints. Updated in place. index_sql : list[str] SQL INDEX declarations. Updated in place. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. fk_attribute_map : dict, optional Mapping of ``child_attr -> (parent_table, parent_attr)``. Updated in place. @@ -261,30 +264,57 @@ def compile_foreign_key( attributes.append(attr) if primary_key is not None: primary_key.append(attr) - attr_sql.append(ref.heading[attr].sql.replace("NOT NULL ", "", int(is_nullable))) + + # Build foreign key column definition using adapter + parent_attr = ref.heading[attr] + sql_type = parent_attr.sql_type + # For PostgreSQL enum types, qualify with schema name + # Enum type names start with "enum_" (generated hash-based names) + if sql_type.startswith("enum_") and adapter.backend == "postgresql": + sql_type = f"{adapter.quote_identifier(ref.database)}.{adapter.quote_identifier(sql_type)}" + col_def = adapter.format_column_definition( + name=attr, + sql_type=sql_type, + nullable=is_nullable, + default=None, + comment=parent_attr.sql_comment, + ) + attr_sql.append(col_def) + # Track FK attribute mapping for lineage: child_attr -> (parent_table, parent_attr) if fk_attribute_map is not None: parent_table = ref.support[0] # e.g., `schema`.`table` parent_attr = ref.heading[attr].original_name fk_attribute_map[attr] = (parent_table, parent_attr) - # declare the foreign key + # declare the foreign key using adapter for identifier quoting + fk_cols = ", ".join(adapter.quote_identifier(col) for col in ref.primary_key) + pk_cols = ", ".join(adapter.quote_identifier(ref.heading[name].original_name) for name in ref.primary_key) + + # Build referenced table name with proper quoting + # ref.support[0] may have cached quoting from a different backend + # Extract database and table name and rebuild with current adapter + parent_full_name = ref.support[0] + # Try to parse as database.table (with or without quotes) + parts = parent_full_name.replace('"', "").replace("`", "").split(".") + if len(parts) == 2: + ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}" + else: + ref_table_name = adapter.quote_identifier(parts[0]) + foreign_key_sql.append( - "FOREIGN KEY (`{fk}`) REFERENCES {ref} (`{pk}`) ON UPDATE CASCADE ON DELETE RESTRICT".format( - fk="`,`".join(ref.primary_key), - pk="`,`".join(ref.heading[name].original_name for name in ref.primary_key), - ref=ref.support[0], - ) + f"FOREIGN 
KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT" ) # declare unique index if is_unique: - index_sql.append("UNIQUE INDEX ({attrs})".format(attrs=",".join("`%s`" % attr for attr in ref.primary_key))) + index_cols = ", ".join(adapter.quote_identifier(attr) for attr in ref.primary_key) + index_sql.append(f"UNIQUE INDEX ({index_cols})") def prepare_declare( - definition: str, context: dict -) -> tuple[str, list[str], list[str], list[str], list[str], list[str], dict[str, tuple[str, str]]]: + definition: str, context: dict, adapter +) -> tuple[str, list[str], list[str], list[str], list[str], list[str], dict[str, tuple[str, str]], dict[str, str]]: """ Parse a table definition into its components. @@ -294,11 +324,13 @@ def prepare_declare( DataJoint table definition string. context : dict Namespace for resolving foreign key references. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Returns ------- tuple - Seven-element tuple containing: + Eight-element tuple containing: - table_comment : str - primary_key : list[str] @@ -307,6 +339,7 @@ def prepare_declare( - index_sql : list[str] - external_stores : list[str] - fk_attribute_map : dict[str, tuple[str, str]] + - column_comments : dict[str, str] - Column name to comment mapping """ # split definition into lines definition = re.split(r"\s*\n\s*", definition.strip()) @@ -322,11 +355,12 @@ def prepare_declare( index_sql = [] external_stores = [] fk_attribute_map = {} # child_attr -> (parent_table, parent_attr) + column_comments = {} # column_name -> comment (for PostgreSQL COMMENT ON) for line in definition: if not line or line.startswith("#"): # ignore additional comments pass - elif line.startswith("---") or line.startswith("___"): + elif line.startswith("---"): in_key = False # start parsing dependent attributes elif is_foreign_key(line): compile_foreign_key( @@ -337,12 +371,13 @@ def prepare_declare( attribute_sql, foreign_key_sql, index_sql, + adapter, fk_attribute_map, ) elif re.match(r"^(unique\s+)?index\s*.*$", line, re.I): # index - compile_index(line, index_sql) + compile_index(line, index_sql, adapter) else: - name, sql, store = compile_attribute(line, in_key, foreign_key_sql, context) + name, sql, store, comment = compile_attribute(line, in_key, foreign_key_sql, context, adapter) if store: external_stores.append(store) if in_key and name not in primary_key: @@ -350,6 +385,8 @@ def prepare_declare( if name not in attributes: attributes.append(name) attribute_sql.append(sql) + if comment: + column_comments[name] = comment return ( table_comment, @@ -359,40 +396,55 @@ def prepare_declare( index_sql, external_stores, fk_attribute_map, + column_comments, ) def declare( - full_table_name: str, definition: str, context: dict -) -> tuple[str, list[str], list[str], dict[str, tuple[str, str]]]: + full_table_name: str, definition: str, context: dict, adapter +) -> tuple[str, list[str], list[str], dict[str, tuple[str, str]], list[str], list[str]]: r""" Parse a definition and generate SQL CREATE TABLE statement. Parameters ---------- full_table_name : str - Fully qualified table name (e.g., ```\`schema\`.\`table\```). + Fully qualified table name (e.g., ```\`schema\`.\`table\``` or ```"schema"."table"```). definition : str DataJoint table definition string. context : dict Namespace for resolving foreign key references. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. 
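For orientation, a sketch of the column-formatting hook used above (``adapter`` is assumed to be in scope; the rendered strings are assumptions, since the concrete adapters are defined elsewhere in this patch):

col_def = adapter.format_column_definition(
    name="mouse_id",
    sql_type="int",
    nullable=False,
    default=None,
    comment="unique animal id",
)
# MySQL adapter (assumed):      `mouse_id` int NOT NULL COMMENT "unique animal id"
# PostgreSQL adapter (assumed): "mouse_id" int NOT NULL   (comment emitted later via COMMENT ON COLUMN)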
Returns ------- tuple - Four-element tuple: + Six-element tuple: - sql : str - SQL CREATE TABLE statement - external_stores : list[str] - External store names used - primary_key : list[str] - Primary key attribute names - fk_attribute_map : dict - FK attribute lineage mapping + - pre_ddl : list[str] - DDL statements to run BEFORE CREATE TABLE (e.g., CREATE TYPE) + - post_ddl : list[str] - DDL statements to run AFTER CREATE TABLE (e.g., COMMENT ON) Raises ------ DataJointError If table name exceeds max length or has no primary key. """ - table_name = full_table_name.strip("`").split(".")[1] + # Parse table name without assuming quote character + # Extract schema.table from quoted name using adapter + quote_char = adapter.quote_identifier("x")[0] # Get quote char from adapter + parts = full_table_name.split(".") + if len(parts) == 2: + schema_name = parts[0].strip(quote_char) + table_name = parts[1].strip(quote_char) + else: + schema_name = None + table_name = parts[0].strip(quote_char) + if len(table_name) > MAX_TABLE_NAME_LENGTH: raise DataJointError( "Table name `{name}` exceeds the max length of {max_length}".format( @@ -408,35 +460,87 @@ def declare( index_sql, external_stores, fk_attribute_map, - ) = prepare_declare(definition, context) + column_comments, + ) = prepare_declare(definition, context, adapter) # Add hidden job metadata for Computed/Imported tables (not parts) - # Note: table_name may still have backticks, strip them for prefix checking - clean_table_name = table_name.strip("`") if config.jobs.add_job_metadata: # Check if this is a Computed (__) or Imported (_) table, but not a Part (contains __ in middle) - is_computed = clean_table_name.startswith("__") and "__" not in clean_table_name[2:] - is_imported = clean_table_name.startswith("_") and not clean_table_name.startswith("__") + is_computed = table_name.startswith("__") and "__" not in table_name[2:] + is_imported = table_name.startswith("_") and not table_name.startswith("__") if is_computed or is_imported: - job_metadata_sql = [ - "`_job_start_time` datetime(3) DEFAULT NULL", - "`_job_duration` float DEFAULT NULL", - "`_job_version` varchar(64) DEFAULT ''", - ] + job_metadata_sql = adapter.job_metadata_columns() attribute_sql.extend(job_metadata_sql) if not primary_key: - raise DataJointError("Table must have a primary key") + # Singleton table: add hidden sentinel attribute + primary_key = ["_singleton"] + singleton_comment = ":bool:singleton primary key" + sql_type = adapter.core_type_to_sql("bool") + singleton_sql = adapter.format_column_definition( + name="_singleton", + sql_type=sql_type, + nullable=False, + default="NOT NULL DEFAULT TRUE", + comment=singleton_comment, + ) + attribute_sql.insert(0, singleton_sql) + column_comments["_singleton"] = singleton_comment + + pre_ddl = [] # DDL to run BEFORE CREATE TABLE (e.g., CREATE TYPE for enums) + post_ddl = [] # DDL to run AFTER CREATE TABLE (e.g., COMMENT ON) + # Get pending enum type DDL for PostgreSQL (must run before CREATE TABLE) + if schema_name and hasattr(adapter, "get_pending_enum_ddl"): + pre_ddl.extend(adapter.get_pending_enum_ddl(schema_name)) + + # Build PRIMARY KEY clause using adapter + pk_cols = ", ".join(adapter.quote_identifier(pk) for pk in primary_key) + pk_clause = f"PRIMARY KEY ({pk_cols})" + + # Handle indexes - inline for MySQL, separate CREATE INDEX for PostgreSQL + if adapter.supports_inline_indexes: + # MySQL: include indexes in CREATE TABLE + create_table_indexes = index_sql + else: + # PostgreSQL: convert to CREATE INDEX statements for 
post_ddl + create_table_indexes = [] + for idx_def in index_sql: + # Parse index definition: "unique index (cols)" or "index (cols)" + idx_match = re.match(r"(unique\s+)?index\s*\(([^)]+)\)", idx_def, re.I) + if idx_match: + is_unique = idx_match.group(1) is not None + # Extract column names (may be quoted or have expressions) + cols_str = idx_match.group(2) + # Simple split on comma - columns are already quoted + columns = [c.strip().strip('`"') for c in cols_str.split(",")] + # Generate CREATE INDEX DDL + create_idx_ddl = adapter.create_index_ddl(full_table_name, columns, unique=is_unique) + post_ddl.append(create_idx_ddl) + + # Assemble CREATE TABLE sql = ( - "CREATE TABLE IF NOT EXISTS %s (\n" % full_table_name - + ",\n".join(attribute_sql + ["PRIMARY KEY (`" + "`,`".join(primary_key) + "`)"] + foreign_key_sql + index_sql) - + '\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment + f"CREATE TABLE IF NOT EXISTS {full_table_name} (\n" + + ",\n".join(attribute_sql + [pk_clause] + foreign_key_sql + create_table_indexes) + + f"\n) {adapter.table_options_clause(table_comment)}" ) - return sql, external_stores, primary_key, fk_attribute_map + + # Add table-level comment DDL if needed (PostgreSQL) + table_comment_ddl = adapter.table_comment_ddl(full_table_name, table_comment) + if table_comment_ddl: + post_ddl.append(table_comment_ddl) + + # Add column-level comments DDL if needed (PostgreSQL) + # Column comments contain type specifications like ::user_comment + for col_name, comment in column_comments.items(): + col_comment_ddl = adapter.column_comment_ddl(full_table_name, col_name, comment) + if col_comment_ddl: + post_ddl.append(col_comment_ddl) + + return sql, external_stores, primary_key, fk_attribute_map, pre_ddl, post_ddl -def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str]) -> list[str]: +def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str], adapter) -> list[str]: """ Generate SQL ALTER commands for attribute changes. @@ -448,6 +552,8 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] Old attribute SQL declarations. primary_key : list[str] Primary key attribute names (cannot be altered). + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Returns ------- @@ -459,8 +565,9 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] DataJointError If an attribute is renamed twice or renamed from non-existent attribute. 
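A sketch of how a caller is presumably expected to consume the new return values of ``declare()`` (the table-creation call site is outside this hunk; ``connection``, ``full_table_name``, ``definition``, ``context`` and ``adapter`` are assumed to be in scope):

sql, stores, primary_key, fk_map, pre_ddl, post_ddl = declare(
    full_table_name, definition, context, adapter
)
for statement in pre_ddl:      # e.g. CREATE TYPE ... AS ENUM (PostgreSQL enums)
    connection.query(statement)
connection.query(sql)          # CREATE TABLE IF NOT EXISTS ...
for statement in post_ddl:     # e.g. CREATE INDEX, COMMENT ON TABLE/COLUMN (PostgreSQL)
    connection.query(statement)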
""" - # parse attribute names - name_regexp = re.compile(r"^`(?P\w+)`") + # parse attribute names - use adapter's quote character + quote_char = re.escape(adapter.quote_identifier("x")[0]) + name_regexp = re.compile(rf"^{quote_char}(?P\w+){quote_char}") original_regexp = re.compile(r'COMMENT "{\s*(?P\w+)\s*}') matched = ((name_regexp.match(d), original_regexp.search(d)) for d in new) new_names = dict((d.group("name"), n and n.group("name")) for d, n in matched) @@ -486,7 +593,7 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] # dropping attributes to_drop = [n for n in old_names if n not in renamed and n not in new_names] - sql = ["DROP `%s`" % n for n in to_drop] + sql = [f"DROP {adapter.quote_identifier(n)}" for n in to_drop] old_names = [name for name in old_names if name not in to_drop] # add or change attributes in order @@ -503,25 +610,24 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] if idx >= 1 and old_names[idx - 1] != (prev[1] or prev[0]): after = prev[0] if new_def not in old or after: - sql.append( - "{command} {new_def} {after}".format( - command=( - "ADD" - if (old_name or new_name) not in old_names - else "MODIFY" - if not old_name - else "CHANGE `%s`" % old_name - ), - new_def=new_def, - after="" if after is None else "AFTER `%s`" % after, - ) - ) + # Determine command type + if (old_name or new_name) not in old_names: + command = "ADD" + elif not old_name: + command = "MODIFY" + else: + command = f"CHANGE {adapter.quote_identifier(old_name)}" + + # Build after clause + after_clause = "" if after is None else f"AFTER {adapter.quote_identifier(after)}" + + sql.append(f"{command} {new_def} {after_clause}") prev = new_name, old_name return sql -def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str], list[str]]: +def alter(definition: str, old_definition: str, context: dict, adapter) -> tuple[list[str], list[str]]: """ Generate SQL ALTER commands for table definition changes. @@ -533,6 +639,8 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str Current table definition. context : dict Namespace for resolving foreign key references. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. 
Returns ------- @@ -555,7 +663,8 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str index_sql, external_stores, _fk_attribute_map, - ) = prepare_declare(definition, context) + _column_comments, + ) = prepare_declare(definition, context, adapter) ( table_comment_, primary_key_, @@ -564,7 +673,8 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str index_sql_, external_stores_, _fk_attribute_map_, - ) = prepare_declare(old_definition, context) + _column_comments_, + ) = prepare_declare(old_definition, context, adapter) # analyze differences between declarations sql = list() @@ -575,9 +685,12 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str if index_sql != index_sql_: raise NotImplementedError("table.alter cannot alter indexes (yet)") if attribute_sql != attribute_sql_: - sql.extend(_make_attribute_alter(attribute_sql, attribute_sql_, primary_key)) + sql.extend(_make_attribute_alter(attribute_sql, attribute_sql_, primary_key, adapter)) if table_comment != table_comment_: - sql.append('COMMENT="%s"' % table_comment) + # For MySQL: COMMENT="new comment" + # For PostgreSQL: would need COMMENT ON TABLE, but that's not an ALTER TABLE clause + # Keep MySQL syntax for now (ALTER TABLE ... COMMENT="...") + sql.append(f'COMMENT="{table_comment}"') return sql, [e for e in external_stores if e not in external_stores_] @@ -620,7 +733,7 @@ def _parse_index_args(args: str) -> list[str]: return [arg for arg in result if arg] # Filter empty strings -def compile_index(line: str, index_sql: list[str]) -> None: +def compile_index(line: str, index_sql: list[str], adapter) -> None: """ Parse an index declaration and append SQL to index_sql. @@ -631,6 +744,8 @@ def compile_index(line: str, index_sql: list[str]) -> None: ``"unique index(attr)"``). index_sql : list[str] List of index SQL declarations. Updated in place. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Raises ------ @@ -639,11 +754,11 @@ def compile_index(line: str, index_sql: list[str]) -> None: """ def format_attribute(attr): - match, attr = translate_attribute(attr) + match, attr = translate_attribute(attr, adapter) if match is None: return attr if match["path"] is None: - return f"`{attr}`" + return adapter.quote_identifier(attr) return f"({attr})" match = re.match(r"(?Punique\s+)?index\s*\(\s*(?P.*)\)", line, re.I) @@ -660,7 +775,7 @@ def format_attribute(attr): ) -def substitute_special_type(match: dict, category: str, foreign_key_sql: list[str], context: dict) -> None: +def substitute_special_type(match: dict, category: str, foreign_key_sql: list[str], context: dict, adapter) -> None: """ Substitute special types with their native SQL equivalents. @@ -679,6 +794,8 @@ def substitute_special_type(match: dict, category: str, foreign_key_sql: list[st Foreign key declarations (unused, kept for API compatibility). context : dict Namespace for codec lookup (unused, kept for API compatibility). + adapter : DatabaseAdapter + Database adapter for backend-specific type mapping. 
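For context on the substitution below, a sketch of the adapter's type mapping (the MySQL targets come from the core-type table earlier in this patch; the PostgreSQL targets are assumptions):

adapter.core_type_to_sql("bool")    # "tinyint" on MySQL; presumably "boolean" on PostgreSQL
adapter.core_type_to_sql("uuid")    # "binary(16)" on MySQL; presumably a native uuid type on PostgreSQL
adapter.core_type_to_sql("bytes")   # "longblob" on MySQL; "bytea" on PostgreSQL (per the comment in declare.py)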
""" if category == "CODEC": # Codec - resolve to underlying dtype @@ -699,11 +816,11 @@ def substitute_special_type(match: dict, category: str, foreign_key_sql: list[st # Recursively resolve if dtype is also a special type category = match_type(match["type"]) if category in SPECIAL_TYPES: - substitute_special_type(match, category, foreign_key_sql, context) + substitute_special_type(match, category, foreign_key_sql, context, adapter) elif category in CORE_TYPE_NAMES: - # Core DataJoint type - substitute with native SQL type if mapping exists - core_name = category.lower() - sql_type = CORE_TYPE_SQL.get(core_name) + # Core DataJoint type - substitute with native SQL type using adapter + # Pass the full type string (e.g., "varchar(255)") not just category name + sql_type = adapter.core_type_to_sql(match["type"]) if sql_type is not None: match["type"] = sql_type # else: type passes through as-is (json, date, datetime, char, varchar, enum) @@ -711,7 +828,9 @@ def substitute_special_type(match: dict, category: str, foreign_key_sql: list[st raise DataJointError(f"Unknown special type: {category}") -def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], context: dict) -> tuple[str, str, str | None]: +def compile_attribute( + line: str, in_key: bool, foreign_key_sql: list[str], context: dict, adapter +) -> tuple[str, str, str | None, str | None]: """ Convert an attribute definition from DataJoint format to SQL. @@ -725,15 +844,18 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte Foreign key declarations (passed to type substitution). context : dict Namespace for codec lookup. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Returns ------- tuple - Three-element tuple: + Four-element tuple: - name : str - Attribute name - sql : str - SQL column declaration - store : str or None - External store name if applicable + - comment : str or None - Column comment (for PostgreSQL COMMENT ON) Raises ------ @@ -760,8 +882,22 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte match["default"] = "DEFAULT NULL" # nullable attributes default to null else: if match["default"]: - quote = match["default"].split("(")[0].upper() not in CONSTANT_LITERALS and match["default"][0] not in "\"'" - match["default"] = "NOT NULL DEFAULT " + ('"%s"' if quote else "%s") % match["default"] + default_val = match["default"] + base_val = default_val.split("(")[0].upper() + + if base_val in CONSTANT_LITERALS: + # SQL constants like NULL, CURRENT_TIMESTAMP - use as-is + match["default"] = f"NOT NULL DEFAULT {default_val}" + elif default_val.startswith('"') and default_val.endswith('"'): + # Double-quoted string - convert to single quotes for PostgreSQL + inner = default_val[1:-1].replace("'", "''") # Escape single quotes + match["default"] = f"NOT NULL DEFAULT '{inner}'" + elif default_val.startswith("'"): + # Already single-quoted - use as-is + match["default"] = f"NOT NULL DEFAULT {default_val}" + else: + # Unquoted value - wrap in single quotes + match["default"] = f"NOT NULL DEFAULT '{default_val}'" else: match["default"] = "NOT NULL" @@ -775,7 +911,7 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte if category in SPECIAL_TYPES: # Core types and Codecs are recorded in comment for reconstruction match["comment"] = ":{type}:{comment}".format(**match) - substitute_special_type(match, category, foreign_key_sql, context) + substitute_special_type(match, category, foreign_key_sql, 
context, adapter) elif category in NATIVE_TYPES: # Native type - warn user logger.warning( @@ -789,5 +925,12 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte if ("blob" in final_type) and match["default"] not in {"DEFAULT NULL", "NOT NULL"}: raise DataJointError("The default value for blob attributes can only be NULL in:\n{line}".format(line=line)) - sql = ("`{name}` {type} {default}" + (' COMMENT "{comment}"' if match["comment"] else "")).format(**match) - return match["name"], sql, match.get("store") + # Use adapter to format column definition + sql = adapter.format_column_definition( + name=match["name"], + sql_type=match["type"], + nullable=match["nullable"], + default=match["default"] if match["default"] else None, + comment=match["comment"] if match["comment"] else None, + ) + return match["name"], sql, match.get("store"), match["comment"] if match["comment"] else None diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py index 621011426..83162a112 100644 --- a/src/datajoint/dependencies.py +++ b/src/datajoint/dependencies.py @@ -31,8 +31,14 @@ def extract_master(part_table: str) -> str | None: str or None Master table name if part_table is a part table, None otherwise. """ - match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", part_table) - return match["master"] + "`" if match else None + # Match both MySQL backticks and PostgreSQL double quotes + # MySQL: `schema`.`master__part` + # PostgreSQL: "schema"."master__part" + match = re.match(r'(?P(?P[`"])[\w]+(?P=q)\.(?P=q)#?[\w]+)__[\w]+(?P=q)', part_table) + if match: + q = match["q"] + return match["master"] + q + return None def topo_sort(graph: nx.DiGraph) -> list[str]: @@ -131,6 +137,7 @@ def __init__(self, connection=None) -> None: def clear(self) -> None: """Clear the graph and reset loaded state.""" self._loaded = False + self._node_alias_count = itertools.count() # reset alias IDs for consistency super().clear() def load(self, force: bool = True) -> None: @@ -151,39 +158,105 @@ def load(self, force: bool = True) -> None: self.clear() - # load primary key info - keys = self._conn.query( - """ - SELECT - concat('`', table_schema, '`.`', table_name, '`') as tab, column_name + # Get adapter for backend-specific SQL generation + adapter = self._conn.adapter + + # Build schema list for IN clause + schemas_list = ", ".join(adapter.quote_string(s) for s in self._conn.schemas) + + # Backend-specific queries for primary keys and foreign keys + # Note: Both PyMySQL and psycopg2 use %s placeholders, so escape % as %% + like_pattern = "'~%%'" + + if adapter.backend == "mysql": + # MySQL: use concat() and MySQL-specific information_schema columns + tab_expr = "concat('`', table_schema, '`.`', table_name, '`')" + + # load primary key info (MySQL uses constraint_name='PRIMARY') + keys = self._conn.query( + f""" + SELECT {tab_expr} as tab, column_name FROM information_schema.key_column_usage - WHERE table_name not LIKE "~%%" AND table_schema in ('{schemas}') AND constraint_name="PRIMARY" - """.format(schemas="','".join(self._conn.schemas)) - ) - pks = defaultdict(set) - for key in keys: - pks[key[0]].add(key[1]) + WHERE table_name NOT LIKE {like_pattern} + AND table_schema in ({schemas_list}) + AND constraint_name='PRIMARY' + """ + ) + pks = defaultdict(set) + for key in keys: + pks[key[0]].add(key[1]) + + # load foreign keys (MySQL has referenced_* columns) + ref_tab_expr = "concat('`', referenced_table_schema, '`.`', referenced_table_name, '`')" + fk_keys = self._conn.query( + f""" + SELECT 
constraint_name, + {tab_expr} as referencing_table, + {ref_tab_expr} as referenced_table, + column_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE referenced_table_name NOT LIKE {like_pattern} + AND (referenced_table_schema in ({schemas_list}) + OR referenced_table_schema is not NULL AND table_schema in ({schemas_list})) + """, + as_dict=True, + ) + else: + # PostgreSQL: use || concatenation and different query structure + tab_expr = "'\"' || kcu.table_schema || '\".\"' || kcu.table_name || '\"'" + + # load primary key info (PostgreSQL uses constraint_type='PRIMARY KEY') + keys = self._conn.query( + f""" + SELECT {tab_expr} as tab, kcu.column_name + FROM information_schema.key_column_usage kcu + JOIN information_schema.table_constraints tc + ON kcu.constraint_name = tc.constraint_name + AND kcu.table_schema = tc.table_schema + WHERE kcu.table_name NOT LIKE {like_pattern} + AND kcu.table_schema in ({schemas_list}) + AND tc.constraint_type = 'PRIMARY KEY' + """ + ) + pks = defaultdict(set) + for key in keys: + pks[key[0]].add(key[1]) + + # load foreign keys using pg_constraint system catalogs + # The information_schema approach creates a Cartesian product for composite FKs + # because constraint_column_usage doesn't have ordinal_position. + # Using pg_constraint with unnest(conkey, confkey) WITH ORDINALITY gives correct mapping. + fk_keys = self._conn.query( + f""" + SELECT + c.conname as constraint_name, + '"' || ns1.nspname || '"."' || cl1.relname || '"' as referencing_table, + '"' || ns2.nspname || '"."' || cl2.relname || '"' as referenced_table, + a1.attname as column_name, + a2.attname as referenced_column_name + FROM pg_constraint c + JOIN pg_class cl1 ON c.conrelid = cl1.oid + JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid + JOIN pg_class cl2 ON c.confrelid = cl2.oid + JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid + CROSS JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS cols(conkey, confkey, ord) + JOIN pg_attribute a1 ON a1.attrelid = cl1.oid AND a1.attnum = cols.conkey + JOIN pg_attribute a2 ON a2.attrelid = cl2.oid AND a2.attnum = cols.confkey + WHERE c.contype = 'f' + AND cl1.relname NOT LIKE {like_pattern} + AND (ns2.nspname in ({schemas_list}) + OR ns1.nspname in ({schemas_list})) + ORDER BY c.conname, cols.ord + """, + as_dict=True, + ) # add nodes to the graph for n, pk in pks.items(): self.add_node(n, primary_key=pk) - # load foreign keys - keys = ( - {k.lower(): v for k, v in elem.items()} - for elem in self._conn.query( - """ - SELECT constraint_name, - concat('`', table_schema, '`.`', table_name, '`') as referencing_table, - concat('`', referenced_table_schema, '`.`', referenced_table_name, '`') as referenced_table, - column_name, referenced_column_name - FROM information_schema.key_column_usage - WHERE referenced_table_name NOT LIKE "~%%" AND (referenced_table_schema in ('{schemas}') OR - referenced_table_schema is not NULL AND table_schema in ('{schemas}')) - """.format(schemas="','".join(self._conn.schemas)), - as_dict=True, - ) - ) + # Process foreign keys (same for both backends) + keys = ({k.lower(): v for k, v in elem.items()} for elem in fk_keys) fks = defaultdict(lambda: dict(attr_map=dict())) for key in keys: d = fks[ diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index c52340f46..48e18fd0d 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -16,6 +16,7 @@ from .dependencies import topo_sort from .errors import DataJointError +from .settings import config from .table 
import Table, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier @@ -90,12 +91,19 @@ class Diagram(nx.DiGraph): ----- ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``. Only tables loaded in the connection are displayed. + + Layout direction is controlled via ``dj.config.display.diagram_direction`` + (default ``"TB"``). Use ``dj.config.override()`` to change temporarily:: + + with dj.config.override(display_diagram_direction="LR"): + dj.Diagram(schema).draw() """ def __init__(self, source, context=None) -> None: if isinstance(source, Diagram): # copy constructor self.nodes_to_show = set(source.nodes_to_show) + self._expanded_nodes = set(source._expanded_nodes) self.context = source.context super().__init__(source) return @@ -134,8 +142,11 @@ def __init__(self, source, context=None) -> None: except AttributeError: raise DataJointError("Cannot plot Diagram for %s" % repr(source)) for node in self: - if node.startswith("`%s`" % database): + # Handle both MySQL backticks and PostgreSQL double quotes + if node.startswith("`%s`" % database) or node.startswith('"%s"' % database): self.nodes_to_show.add(node) + # All nodes start as expanded + self._expanded_nodes = set(self.nodes_to_show) @classmethod def from_sequence(cls, sequence) -> "Diagram": @@ -173,6 +184,34 @@ def is_part(part, master): self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show)) return self + def collapse(self) -> "Diagram": + """ + Mark all nodes in this diagram as collapsed. + + Collapsed nodes are shown as a single node per schema. When combined + with other diagrams using ``+``, expanded nodes win: if a node is + expanded in either operand, it remains expanded in the result. + + Returns + ------- + Diagram + A copy of this diagram with all nodes collapsed. + + Examples + -------- + >>> # Show schema1 expanded, schema2 collapsed into single nodes + >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse() + + >>> # Collapse all three schemas together + >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse() + + >>> # Expand one table from collapsed schema + >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable) + """ + result = Diagram(self) + result._expanded_nodes = set() # All nodes collapsed + return result + def __add__(self, arg) -> "Diagram": """ Union or downstream expansion. @@ -187,21 +226,31 @@ def __add__(self, arg) -> "Diagram": Diagram Combined or expanded diagram. 
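A usage sketch of the integer branch of ``__add__`` (``Session`` is a placeholder table class):

import datajoint as dj

diag = dj.Diagram(Session) + 2   # Session plus tables up to two foreign-key links downstream
diag.draw()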
""" - self = Diagram(self) # copy + result = Diagram(self) # copy try: - self.nodes_to_show.update(arg.nodes_to_show) + # Merge nodes and edges from the other diagram + result.add_nodes_from(arg.nodes(data=True)) + result.add_edges_from(arg.edges(data=True)) + result.nodes_to_show.update(arg.nodes_to_show) + # Merge contexts for class name lookups + result.context = {**result.context, **arg.context} + # Expanded wins: union of expanded nodes from both operands + result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes except AttributeError: try: - self.nodes_to_show.add(arg.full_table_name) + result.nodes_to_show.add(arg.full_table_name) + result._expanded_nodes.add(arg.full_table_name) except AttributeError: for i in range(arg): - new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show) + new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show) if not new: break # add nodes referenced by aliased nodes - new.update(nx.algorithms.boundary.node_boundary(self, (a for a in new if a.isdigit()))) - self.nodes_to_show.update(new) - return self + new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit()))) + result.nodes_to_show.update(new) + # New nodes from expansion are expanded + result._expanded_nodes = result._expanded_nodes | result.nodes_to_show + return result def __sub__(self, arg) -> "Diagram": """ @@ -274,7 +323,9 @@ def _make_graph(self) -> nx.DiGraph: """ # mark "distinguished" tables, i.e. those that introduce new primary key # attributes - for name in self.nodes_to_show: + # Filter nodes_to_show to only include nodes that exist in the graph + valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) + for name in valid_nodes: foreign_attributes = set( attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"] ) @@ -282,21 +333,210 @@ def _make_graph(self) -> nx.DiGraph: "primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"] ) # include aliased nodes that are sandwiched between two displayed nodes - gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection( - nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show) + gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection( + nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes) ) - nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit) + nodes = valid_nodes.union(a for a in gaps if a.isdigit()) # construct subgraph and rename nodes to class names graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes)) nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph}) # relabel nodes to class names mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()} - new_names = [mapping.values()] + new_names = list(mapping.values()) if len(new_names) > len(set(new_names)): raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.") nx.relabel_nodes(graph, mapping, copy=False) return graph + def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]: + """ + Apply collapse logic to the graph. + + Nodes in nodes_to_show but not in _expanded_nodes are collapsed into + single schema nodes. + + Parameters + ---------- + graph : nx.DiGraph + The graph from _make_graph(). 
+ + Returns + ------- + tuple[nx.DiGraph, dict[str, str]] + Modified graph and mapping of collapsed schema labels to their table count. + """ + # Filter to valid nodes (those that exist in the underlying graph) + valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) + valid_expanded = self._expanded_nodes.intersection(set(self.nodes())) + + # If all nodes are expanded, no collapse needed + if valid_expanded >= valid_nodes: + return graph, {} + + # Map full_table_names to class_names + full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes} + class_to_full = {v: k for k, v in full_to_class.items()} + + # Identify expanded class names + expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded} + + # Identify nodes to collapse (class names) + nodes_to_collapse = set(graph.nodes()) - expanded_class_names + + if not nodes_to_collapse: + return graph, {} + + # Group collapsed nodes by schema + collapsed_by_schema = {} # schema_name -> list of class_names + for class_name in nodes_to_collapse: + full_name = class_to_full.get(class_name) + if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] + if schema_name not in collapsed_by_schema: + collapsed_by_schema[schema_name] = [] + collapsed_by_schema[schema_name].append(class_name) + + if not collapsed_by_schema: + return graph, {} + + # Determine labels for collapsed schemas + schema_modules = {} + for schema_name, class_names in collapsed_by_schema.items(): + schema_modules[schema_name] = set() + for class_name in class_names: + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Collect module names for ALL schemas in the diagram (not just collapsed) + all_schema_modules = {} # schema_name -> module_name + for node in graph.nodes(): + full_name = class_to_full.get(node) + if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + db_schema = parts[1] + cls = self._resolve_class(node) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + all_schema_modules[db_schema] = module_name + + # Check which module names are shared by multiple schemas + module_to_schemas = {} + for db_schema, module_name in all_schema_modules.items(): + if module_name not in module_to_schemas: + module_to_schemas[module_name] = [] + module_to_schemas[module_name].append(db_schema) + + ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1} + + # Determine labels for collapsed schemas + collapsed_labels = {} # schema_name -> label + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + module_name = next(iter(modules)) + # Use database schema name if module is ambiguous + if module_name in ambiguous_modules: + label = schema_name + else: + label = module_name + else: + label = schema_name + collapsed_labels[schema_name] = label + + # Build counts using final labels + collapsed_counts = {} # label -> count of tables + for schema_name, class_names in collapsed_by_schema.items(): + label = collapsed_labels[schema_name] + collapsed_counts[label] = len(class_names) + + # Create new graph with collapsed nodes + new_graph = nx.DiGraph() + + # Map old node names to new names (collapsed nodes -> schema label) + node_mapping = {} + for node in graph.nodes(): + full_name = class_to_full.get(node) 
+ if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2 and node in nodes_to_collapse: + schema_name = parts[1] + node_mapping[node] = collapsed_labels[schema_name] + else: + node_mapping[node] = node + else: + # Alias nodes - check if they should be collapsed + # An alias node should be collapsed if ALL its neighbors are collapsed + neighbors = set(graph.predecessors(node)) | set(graph.successors(node)) + if neighbors and neighbors <= nodes_to_collapse: + # Get schema from first neighbor + neighbor = next(iter(neighbors)) + full_name = class_to_full.get(neighbor) + if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] + node_mapping[node] = collapsed_labels[schema_name] + continue + node_mapping[node] = node + + # Build reverse mapping: label -> schema_name + label_to_schema = {label: schema for schema, label in collapsed_labels.items()} + + # Add nodes + added_collapsed = set() + for old_node, new_node in node_mapping.items(): + if new_node in collapsed_counts: + # This is a collapsed schema node + if new_node not in added_collapsed: + schema_name = label_to_schema.get(new_node, new_node) + new_graph.add_node( + new_node, + node_type=None, + collapsed=True, + table_count=collapsed_counts[new_node], + schema_name=schema_name, + ) + added_collapsed.add(new_node) + else: + new_graph.add_node(new_node, **graph.nodes[old_node]) + + # Add edges (avoiding self-loops and duplicates) + for src, dest, data in graph.edges(data=True): + new_src = node_mapping[src] + new_dest = node_mapping[dest] + if new_src != new_dest and not new_graph.has_edge(new_src, new_dest): + new_graph.add_edge(new_src, new_dest, **data) + + return new_graph, collapsed_counts + + def _resolve_class(self, name: str): + """ + Safely resolve a table class from a dotted name without eval(). + + Parameters + ---------- + name : str + Dotted class name like "MyTable" or "Module.MyTable". + + Returns + ------- + type or None + The table class if found, otherwise None. + """ + parts = name.split(".") + obj = self.context.get(parts[0]) + for part in parts[1:]: + if obj is None: + return None + obj = getattr(obj, part, None) + if obj is not None and isinstance(obj, type) and issubclass(obj, Table): + return obj + return None + @staticmethod def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None: """ @@ -330,8 +570,78 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None: ) def make_dot(self): + """ + Generate a pydot graph object. + + Returns + ------- + pydot.Dot + The graph object ready for rendering. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. 
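Tying the layout direction to rendering (the override pattern is the one from the class docstring; the filenames are illustrative):

with dj.config.override(display_diagram_direction="LR"):
    dj.Diagram(schema).save("pipeline.svg")   # format inferred from the extension
    dj.Diagram(schema).save("pipeline.mmd")   # Mermaid source via make_mermaid()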
+ """ + direction = config.display.diagram_direction graph = self._make_graph() - graph.nodes() + + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + + # Build schema mapping: class_name -> schema_name + # Group by database schema, label with Python module name if 1:1 mapping + schema_map = {} # class_name -> schema_name + schema_modules = {} # schema_name -> set of module names + + for full_name in self.nodes_to_show: + # Extract schema from full table name like `schema`.`table` or "schema"."table" + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] # schema is between first pair of backticks + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} # schema_name -> label + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) + else: + cluster_labels[schema_name] = schema_name + + # Disambiguate labels if multiple schemas share the same module name + # (e.g., all defined in __main__ in a notebook) + label_counts = {} + for label in cluster_labels.values(): + label_counts[label] = label_counts.get(label, 0) + 1 + + for schema_name, label in cluster_labels.items(): + if label_counts[label] > 1: + # Multiple schemas share this module name - add schema name + cluster_labels[schema_name] = f"{label} ({schema_name})" + + # Assign alias nodes (orange dots) to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + # Find the child (successor) - the table that declares the renamed FK + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] + + # Assign collapsed nodes to their schema so they appear in the cluster + for node, data in graph.nodes(data=True): + if data.get("collapsed") and data.get("schema_name"): + schema_map[node] = data["schema_name"] scale = 1.2 # scaling factor for fonts and boxes label_props = { # http://matplotlib.org/examples/color/named_colors.html @@ -372,8 +682,8 @@ def make_dot(self): color="#FF000020", fontcolor="#7F0000A0", fontsize=round(scale * 10), - size=0.3 * scale, - fixed=True, + size=0.4 * scale, + fixed=False, ), Imported: dict( shape="ellipse", @@ -385,18 +695,33 @@ def make_dot(self): ), Part: dict( shape="plaintext", - color="#0000000", + color="#00000000", fontcolor="black", fontsize=round(scale * 8), size=0.1 * scale, fixed=False, ), + "collapsed": dict( + shape="box3d", + color="#80808060", + fontcolor="#404040", + fontsize=round(scale * 10), + size=0.5 * scale, + fixed=False, + ), } - node_props = {node: label_props[d["node_type"]] for node, d in dict(graph.nodes(data=True)).items()} + # Build node_props, handling collapsed nodes specially + node_props = {} + for node, d in graph.nodes(data=True): + if d.get("collapsed"): + node_props[node] = label_props["collapsed"] + else: + node_props[node] = label_props[d["node_type"]] self._encapsulate_node_names(graph) 
self._encapsulate_edge_attributes(graph) dot = nx.drawing.nx_pydot.to_pydot(graph) + dot.set_rankdir(direction) for node in dot.get_nodes(): node.set_shape("circle") name = node.get_name().strip('"') @@ -408,17 +733,36 @@ def make_dot(self): node.set_fixedsize("shape" if props["fixed"] else False) node.set_width(props["size"]) node.set_height(props["size"]) - if name.split(".")[0] in self.context: - cls = eval(name, self.context) - assert issubclass(cls, Table) - description = cls().describe(context=self.context).split("\n") - description = ( - ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0])) - for q in description - if not q.startswith("#") - ) - node.set_tooltip(" ".join(description)) - node.set_label("<" + name + ">" if node.get("distinguished") == "True" else name) + + # Handle collapsed nodes specially + node_data = graph.nodes.get(f'"{name}"', {}) + if node_data.get("collapsed"): + table_count = node_data.get("table_count", 0) + label = f"({table_count} tables)" if table_count != 1 else "(1 table)" + node.set_label(label) + node.set_tooltip(f"Collapsed schema: {table_count} tables") + else: + cls = self._resolve_class(name) + if cls is not None: + description = cls().describe(context=self.context).split("\n") + description = ( + ( + "-" * 30 + if q.startswith("---") + else (q.replace("->", "→") if "->" in q else q.split(":")[0]) + ) + for q in description + if not q.startswith("#") + ) + node.set_tooltip(" ".join(description)) + # Strip module prefix from label if it matches the cluster label + display_name = name + schema_name = schema_map.get(name) + if schema_name and "." in name: + prefix = name.rsplit(".", 1)[0] + if prefix == cluster_labels.get(schema_name): + display_name = name.rsplit(".", 1)[1] + node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name) node.set_color(props["color"]) node.set_style("filled") @@ -430,11 +774,41 @@ def make_dot(self): if props is None: raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest)) edge.set_color("#00000040") - edge.set_style("solid" if props["primary"] else "dashed") - master_part = graph.nodes[dest]["node_type"] is Part and dest.startswith(src + ".") + edge.set_style("solid" if props.get("primary") else "dashed") + dest_node_type = graph.nodes[dest].get("node_type") + master_part = dest_node_type is Part and dest.startswith(src + ".") edge.set_weight(3 if master_part else 1) edge.set_arrowhead("none") - edge.set_penwidth(0.75 if props["multi"] else 2) + edge.set_penwidth(0.75 if props.get("multi") else 2) + + # Group nodes into schema clusters (always on) + if schema_map: + import pydot + + # Group nodes by schema + schemas = {} + for node in list(dot.get_nodes()): + name = node.get_name().strip('"') + schema_name = schema_map.get(name) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append(node) + + # Create clusters for each schema + # Use Python module name if 1:1 mapping, otherwise database schema name + for schema_name, nodes in schemas.items(): + label = cluster_labels.get(schema_name, schema_name) + cluster = pydot.Cluster( + f"cluster_{schema_name}", + label=label, + style="dashed", + color="gray", + fontcolor="gray", + ) + for node in nodes: + cluster.add_node(node) + dot.add_subgraph(cluster) return dot @@ -452,6 +826,159 @@ def make_image(self): else: raise DataJointError("pyplot was not imported") + def make_mermaid(self) -> str: + """ + 
Generate Mermaid diagram syntax. + + Produces a flowchart in Mermaid syntax that can be rendered in + Markdown documentation, GitHub, or https://mermaid.live. + + Returns + ------- + str + Mermaid flowchart syntax. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema using Mermaid subgraphs, with the Python + module name shown as the group label when available. + + Examples + -------- + >>> print(dj.Diagram(schema).make_mermaid()) + flowchart TB + subgraph my_pipeline + Mouse[Mouse]:::manual + Session[Session]:::manual + Neuron([Neuron]):::computed + end + Mouse --> Session + Session --> Neuron + """ + graph = self._make_graph() + direction = config.display.diagram_direction + + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + + # Build schema mapping for grouping + schema_map = {} # class_name -> schema_name + schema_modules = {} # schema_name -> set of module names + + for full_name in self.nodes_to_show: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) + else: + cluster_labels[schema_name] = schema_name + + # Assign alias nodes to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] + + lines = [f"flowchart {direction}"] + + # Define class styles matching Graphviz colors + lines.append(" classDef manual fill:#90EE90,stroke:#006400") + lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969") + lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000") + lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B") + lines.append(" classDef part fill:#FFFFFF,stroke:#000000") + lines.append(" classDef collapsed fill:#808080,stroke:#404040") + lines.append("") + + # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box + shape_map = { + Manual: ("[", "]"), # box + Lookup: ("[", "]"), # box + Computed: ("([", "])"), # stadium/pill + Imported: ("([", "])"), # stadium/pill + Part: ("[", "]"), # box + _AliasNode: ("((", "))"), # circle + None: ("((", "))"), # circle + } + + tier_class = { + Manual: "manual", + Lookup: "lookup", + Computed: "computed", + Imported: "imported", + Part: "part", + _AliasNode: "", + None: "", + } + + # Group nodes by schema into subgraphs (including collapsed nodes) + schemas = {} + for node, data in graph.nodes(data=True): + if data.get("collapsed"): + # Collapsed nodes use their schema_name attribute + schema_name = data.get("schema_name") + else: + schema_name = schema_map.get(node) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append((node, data)) + + # Add nodes grouped by 
schema subgraphs + for schema_name, nodes in schemas.items(): + label = cluster_labels.get(schema_name, schema_name) + lines.append(f" subgraph {label}") + for node, data in nodes: + safe_id = node.replace(".", "_").replace(" ", "_") + if data.get("collapsed"): + # Collapsed node - show only table count + table_count = data.get("table_count", 0) + count_text = f"{table_count} tables" if table_count != 1 else "1 table" + lines.append(f' {safe_id}[["({count_text})"]]:::collapsed') + else: + # Regular node + tier = data.get("node_type") + left, right = shape_map.get(tier, ("[", "]")) + cls = tier_class.get(tier, "") + # Strip module prefix from display name if it matches the cluster label + display_name = node + if "." in node: + prefix = node.rsplit(".", 1)[0] + if prefix == label: + display_name = node.rsplit(".", 1)[1] + class_suffix = f":::{cls}" if cls else "" + lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}") + lines.append(" end") + + lines.append("") + + # Add edges + for src, dest, data in graph.edges(data=True): + safe_src = src.replace(".", "_").replace(" ", "_") + safe_dest = dest.replace(".", "_").replace(" ", "_") + # Solid arrow for primary FK, dotted for non-primary + style = "-->" if data.get("primary") else "-.->" + lines.append(f" {safe_src} {style} {safe_dest}") + + return "\n".join(lines) + def _repr_svg_(self): return self.make_svg()._repr_svg_() @@ -472,24 +999,38 @@ def save(self, filename: str, format: str | None = None) -> None: filename : str Output filename. format : str, optional - File format (``'png'`` or ``'svg'``). Inferred from extension if None. + File format (``'png'``, ``'svg'``, or ``'mermaid'``). + Inferred from extension if None. Raises ------ DataJointError If format is unsupported. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. """ if format is None: if filename.lower().endswith(".png"): format = "png" elif filename.lower().endswith(".svg"): format = "svg" + elif filename.lower().endswith((".mmd", ".mermaid")): + format = "mermaid" + if format is None: + raise DataJointError("Could not infer format from filename. 
Specify format explicitly.") if format.lower() == "png": with open(filename, "wb") as f: f.write(self.make_png().getbuffer().tobytes()) elif format.lower() == "svg": with open(filename, "w") as f: f.write(self.make_svg().data) + elif format.lower() == "mermaid": + with open(filename, "w") as f: + f.write(self.make_mermaid()) else: raise DataJointError("Unsupported file format") diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 5ca7fdaa5..6decaf336 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -104,9 +104,10 @@ def primary_key(self): _subquery_alias_count = count() # count for alias names used in the FROM clause def from_clause(self): + adapter = self.connection.adapter support = ( ( - "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count) + "({}) as {}".format(src.make_sql(), adapter.quote_identifier(f"${next(self._subquery_alias_count):x}")) if isinstance(src, QueryExpression) else src ) @@ -116,7 +117,8 @@ def from_clause(self): for s, (is_left, using_attrs) in zip(support, self._joins): left_kw = "LEFT " if is_left else "" if using_attrs: - using = "USING ({})".format(", ".join(f"`{a}`" for a in using_attrs)) + quoted_attrs = ", ".join(adapter.quote_identifier(a) for a in using_attrs) + using = f"USING ({quoted_attrs})" clause += f" {left_kw}JOIN {s} {using}" else: # Cross join (no common non-hidden attributes) @@ -134,7 +136,8 @@ def sorting_clauses(self): return "" # Default to KEY ordering if order_by is None (inherit with no existing order) order_by = self._top.order_by if self._top.order_by is not None else ["KEY"] - clause = ", ".join(_wrap_attributes(_flatten_attribute_list(self.primary_key, order_by))) + adapter = self.connection.adapter + clause = ", ".join(_wrap_attributes(_flatten_attribute_list(self.primary_key, order_by), adapter)) if clause: clause = f" ORDER BY {clause}" if self._top.limit is not None: @@ -150,7 +153,7 @@ def make_sql(self, fields=None): """ return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", - fields=self.heading.as_sql(fields or self.heading.names), + fields=self.heading.as_sql(fields or self.heading.names, adapter=self.connection.adapter), from_=self.from_clause(), where=self.where_clause(), sorting=self.sorting_clauses(), @@ -454,7 +457,8 @@ def proj(self, *attributes, **named_attributes): from other attributes available before the projection. Each attribute name can only be used once. """ - named_attributes = {k: translate_attribute(v)[1] for k, v in named_attributes.items()} + adapter = self.connection.adapter if hasattr(self, "connection") and self.connection else None + named_attributes = {k: translate_attribute(v, adapter)[1] for k, v in named_attributes.items()} # new attributes in parentheses are included again with the new name without removing original duplication_pattern = re.compile(rf"^\s*\(\s*(?!{'|'.join(CONSTANT_LITERALS)})(?P[a-zA-Z_]\w*)\s*\)\s*$") # attributes without parentheses renamed @@ -876,19 +880,23 @@ def __len__(self): """:return: number of elements in the result set e.g. 
``len(q1)``.""" result = self.make_subquery() if self._top else copy.copy(self) has_left_join = any(is_left for is_left, _ in result._joins) - return result.connection.query( - "SELECT {select_} FROM {from_}{where}".format( - select_=( - "count(*)" - if has_left_join - else "count(DISTINCT {fields})".format( - fields=result.heading.as_sql(result.primary_key, include_aliases=False) - ) - ), - from_=result.from_clause(), - where=result.where_clause(), + + # Build COUNT query - PostgreSQL requires different syntax for multi-column DISTINCT + adapter = result.connection.adapter + if has_left_join or len(result.primary_key) > 1: + # Use subquery with DISTINCT for multi-column primary keys (backend-agnostic) + fields = result.heading.as_sql(result.primary_key, include_aliases=False, adapter=adapter) + query = ( + f"SELECT count(*) FROM (" + f"SELECT DISTINCT {fields} FROM {result.from_clause()}{result.where_clause()}" + f") AS distinct_count" ) - ).fetchone()[0] + else: + # Single column - can use count(DISTINCT col) directly + fields = result.heading.as_sql(result.primary_key, include_aliases=False, adapter=adapter) + query = f"SELECT count(DISTINCT {fields}) FROM {result.from_clause()}{result.where_clause()}" + + return result.connection.query(query).fetchone()[0] def __bool__(self): """ @@ -1012,31 +1020,49 @@ def where_clause(self): return "" if not self._left_restrict else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict) def make_sql(self, fields=None): - fields = self.heading.as_sql(fields or self.heading.names) + adapter = self.connection.adapter + fields = self.heading.as_sql(fields or self.heading.names, adapter=adapter) assert self._grouping_attributes or not self.restriction distinct = set(self.heading.names) == set(self.primary_key) - return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format( - distinct="DISTINCT " if distinct else "", - fields=fields, - from_=self.from_clause(), - where=self.where_clause(), - group_by=( - "" - if not self.primary_key - else ( - " GROUP BY `%s`" % "`,`".join(self._grouping_attributes) - + ("" if not self.restriction else " HAVING (%s)" % ")AND(".join(self.restriction)) - ) - ), - sorting=self.sorting_clauses(), - ) - def __len__(self): - return self.connection.query( - "SELECT count(1) FROM ({subquery}) `${alias:x}`".format( - subquery=self.make_sql(), alias=next(self._subquery_alias_count) + # PostgreSQL doesn't allow column aliases in HAVING clause (SQL standard). + # For PostgreSQL with restrictions, wrap aggregation in subquery and use WHERE. 
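+            # Illustrative query shapes (hypothetical attribute names; not emitted verbatim):
+            #   MySQL:      SELECT `subject_id`, count(*) as `n` FROM ... GROUP BY `subject_id` HAVING (n > 2)
+            #   PostgreSQL: SELECT * FROM (
+            #                   SELECT "subject_id", count(*) as "n" FROM ... GROUP BY "subject_id"
+            #               ) AS "_aggr0" WHERE (n > 2)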
+ use_subquery_for_having = adapter.backend == "postgresql" and self.restriction and self._grouping_attributes + + if use_subquery_for_having: + # Generate inner query without HAVING + inner_sql = "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format( + distinct="DISTINCT " if distinct else "", + fields=fields, + from_=self.from_clause(), + where=self.where_clause(), + group_by=" GROUP BY {}".format(", ".join(adapter.quote_identifier(col) for col in self._grouping_attributes)), + ) + # Wrap in subquery with WHERE for the HAVING conditions + subquery_alias = adapter.quote_identifier(f"_aggr{next(self._subquery_alias_count)}") + outer_where = " WHERE (%s)" % ")AND(".join(self.restriction) + return f"SELECT * FROM ({inner_sql}) AS {subquery_alias}{outer_where}{self.sorting_clauses()}" + else: + # MySQL path: use HAVING directly + return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format( + distinct="DISTINCT " if distinct else "", + fields=fields, + from_=self.from_clause(), + where=self.where_clause(), + group_by=( + "" + if not self.primary_key + else ( + " GROUP BY {}".format(", ".join(adapter.quote_identifier(col) for col in self._grouping_attributes)) + + ("" if not self.restriction else " HAVING (%s)" % ")AND(".join(self.restriction)) + ) + ), + sorting=self.sorting_clauses(), ) - ).fetchone()[0] + + def __len__(self): + alias = self.connection.adapter.quote_identifier(f"${next(self._subquery_alias_count):x}") + return self.connection.query(f"SELECT count(1) FROM ({self.make_sql()}) {alias}").fetchone()[0] def __bool__(self): return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0]) @@ -1072,12 +1098,11 @@ def make_sql(self): if not arg1.heading.secondary_attributes and not arg2.heading.secondary_attributes: # no secondary attributes: use UNION DISTINCT fields = arg1.primary_key - return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting}`".format( - sql1=(arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields)), - sql2=(arg2.make_sql() if isinstance(arg2, Union) else arg2.make_sql(fields)), - alias=next(self.__count), - sorting=self.sorting_clauses(), - ) + alias_name = f"_u{next(self.__count)}{self.sorting_clauses()}" + alias_quoted = self.connection.adapter.quote_identifier(alias_name) + sql1 = arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields) + sql2 = arg2.make_sql() if isinstance(arg2, Union) else arg2.make_sql(fields) + return f"SELECT * FROM (({sql1}) UNION ({sql2})) as {alias_quoted}" # with secondary attributes, use union of left join with anti-restriction fields = self.heading.names sql1 = arg1.join(arg2, left=True).make_sql(fields) @@ -1093,12 +1118,8 @@ def where_clause(self): raise NotImplementedError("Union does not use a WHERE clause") def __len__(self): - return self.connection.query( - "SELECT count(1) FROM ({subquery}) `${alias:x}`".format( - subquery=self.make_sql(), - alias=next(QueryExpression._subquery_alias_count), - ) - ).fetchone()[0] + alias = self.connection.adapter.quote_identifier(f"${next(QueryExpression._subquery_alias_count):x}") + return self.connection.query(f"SELECT count(1) FROM ({self.make_sql()}) {alias}").fetchone()[0] def __bool__(self): return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0]) @@ -1242,6 +1263,14 @@ def _flatten_attribute_list(primary_key, attrs): yield a -def _wrap_attributes(attr): - for entry in attr: # wrap attribute names in backquotes - yield 
re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE) +def _wrap_attributes(attr, adapter): + """Wrap attribute names with database-specific quotes.""" + for entry in attr: + # Replace word boundaries (not 'asc' or 'desc') with quoted version + def quote_match(match): + word = match.group(1) + if word.lower() not in ("asc", "desc"): + return adapter.quote_identifier(word) + return word + + yield re.sub(r"\b((?!asc|desc)\w+)\b", quote_match, entry, flags=re.IGNORECASE) diff --git a/src/datajoint/heading.py b/src/datajoint/heading.py index 99d7246a4..e152e075b 100644 --- a/src/datajoint/heading.py +++ b/src/datajoint/heading.py @@ -133,7 +133,7 @@ def sql_comment(self) -> str: Comment with optional ``:uuid:`` prefix. """ # UUID info is stored in the comment for reconstruction - return (":uuid:" if self.uuid else "") + self.comment + return (":uuid:" if self.uuid else "") + (self.comment or "") @property def sql(self) -> str: @@ -164,8 +164,9 @@ def original_name(self) -> str: """ if self.attribute_expression is None: return self.name - assert self.attribute_expression.startswith("`") - return self.attribute_expression.strip("`") + # Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ") + assert self.attribute_expression.startswith(("`", '"')) + return self.attribute_expression.strip('`"') class Heading: @@ -290,7 +291,9 @@ def __repr__(self) -> str: in_key = True ret = "" if self._table_status is not None: - ret += "# " + self.table_status["comment"] + "\n" + comment = self.table_status.get("comment", "") + if comment: + ret += "# " + comment + "\n" for v in self.attributes.values(): if in_key and not v.in_key: ret += "---\n" @@ -319,7 +322,7 @@ def as_dtype(self) -> np.dtype: """ return np.dtype(dict(names=self.names, formats=[v.dtype for v in self.attributes.values()])) - def as_sql(self, fields: list[str], include_aliases: bool = True) -> str: + def as_sql(self, fields: list[str], include_aliases: bool = True, adapter=None) -> str: """ Generate SQL SELECT clause for specified fields. @@ -329,20 +332,37 @@ def as_sql(self, fields: list[str], include_aliases: bool = True) -> str: Attribute names to include. include_aliases : bool, optional Include AS clauses for computed attributes. Default True. + adapter : DatabaseAdapter, optional + Database adapter for identifier quoting. If not provided, attempts + to get from table_info connection. Returns ------- str Comma-separated SQL field list. 
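+        Notes
+        -----
+        Illustrative only (hypothetical attribute names): a plain attribute renders as a
+        quoted identifier (backticks on MySQL, double quotes on PostgreSQL), while a
+        computed attribute renders as ``<expression> as <quoted name>`` when
+        ``include_aliases`` is True.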
""" - return ",".join( - ( - "`%s`" % name - if self.attributes[name].attribute_expression is None - else self.attributes[name].attribute_expression + (" as `%s`" % name if include_aliases else "") - ) - for name in fields - ) + # Get adapter for proper identifier quoting + if adapter is None and self.table_info and "conn" in self.table_info and self.table_info["conn"]: + adapter = self.table_info["conn"].adapter + + def quote(name): + # Use adapter if available, otherwise use ANSI SQL double quotes (not backticks) + return adapter.quote_identifier(name) if adapter else f'"{name}"' + + def render_field(name): + attr = self.attributes[name] + if attr.attribute_expression is None: + return quote(name) + else: + # Translate expression for backend compatibility (e.g., GROUP_CONCAT ↔ STRING_AGG) + expr = attr.attribute_expression + if adapter: + expr = adapter.translate_expression(expr) + if include_aliases: + return f"{expr} as {quote(name)}" + return expr + + return ",".join(render_field(name) for name in fields) def __iter__(self): return iter(self.attributes) @@ -350,38 +370,42 @@ def __iter__(self): def _init_from_database(self) -> None: """Initialize heading from an existing database table.""" conn, database, table_name, context = (self.table_info[k] for k in ("conn", "database", "table_name", "context")) + adapter = conn.adapter + + # Get table metadata info = conn.query( - 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format(table_name=table_name, database=database), + adapter.get_table_info_sql(database, table_name), as_dict=True, ).fetchone() if info is None: - raise DataJointError( - "The table `{database}`.`{table_name}` is not defined.".format(table_name=table_name, database=database) - ) + raise DataJointError(f"The table {database}.{table_name} is not defined.") + # Normalize table_comment to comment for backward compatibility self._table_status = {k.lower(): v for k, v in info.items()} + if "table_comment" in self._table_status: + self._table_status["comment"] = self._table_status["table_comment"] + + # Get column information cur = conn.query( - "SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`".format(table_name=table_name, database=database), + adapter.get_columns_sql(database, table_name), as_dict=True, ) - attributes = cur.fetchall() - - rename_map = { - "Field": "name", - "Type": "type", - "Null": "nullable", - "Default": "default", - "Key": "in_key", - "Comment": "comment", - } + # Parse columns using adapter-specific parser + raw_attributes = cur.fetchall() + attributes = [adapter.parse_column_info(row) for row in raw_attributes] - fields_to_drop = ("Privileges", "Collation") + # Get primary key information and mark primary key columns + pk_query = conn.query( + adapter.get_primary_key_sql(database, table_name), + as_dict=True, + ) + pk_columns = {row["column_name"] for row in pk_query.fetchall()} + for attr in attributes: + if attr["name"] in pk_columns: + attr["key"] = "PRI" - # rename and drop attributes - attributes = [ - {rename_map[k] if k in rename_map else k: v for k, v in x.items() if k not in fields_to_drop} for x in attributes - ] numeric_types = { + # MySQL types ("float", False): np.float64, ("float", True): np.float64, ("double", False): np.float64, @@ -396,6 +420,13 @@ def _init_from_database(self) -> None: ("int", True): np.int64, ("bigint", False): np.int64, ("bigint", True): np.uint64, + # PostgreSQL types + ("integer", False): np.int64, + ("integer", True): np.int64, + ("real", False): np.float64, + ("real", True): np.float64, + 
("double precision", False): np.float64, + ("double precision", True): np.float64, } sql_literals = ["CURRENT_TIMESTAMP"] @@ -403,9 +434,9 @@ def _init_from_database(self) -> None: # additional attribute properties for attr in attributes: attr.update( - in_key=(attr["in_key"] == "PRI"), - nullable=attr["nullable"] == "YES", - autoincrement=bool(re.search(r"auto_increment", attr["Extra"], flags=re.I)), + in_key=(attr["key"] == "PRI"), + nullable=attr["nullable"], # Already boolean from parse_column_info + autoincrement=bool(re.search(r"auto_increment", attr["extra"], flags=re.I)), numeric=any(TYPE_PATTERN[t].match(attr["type"]) for t in ("DECIMAL", "INTEGER", "FLOAT")), string=any(TYPE_PATTERN[t].match(attr["type"]) for t in ("ENUM", "TEMPORAL", "STRING")), is_blob=any(TYPE_PATTERN[t].match(attr["type"]) for t in ("BYTES", "NATIVE_BLOB")), @@ -421,10 +452,12 @@ def _init_from_database(self) -> None: if any(TYPE_PATTERN[t].match(attr["type"]) for t in ("INTEGER", "FLOAT")): attr["type"] = re.sub(r"\(\d+\)", "", attr["type"], count=1) # strip size off integers and floats attr["unsupported"] = not any((attr["is_blob"], attr["numeric"], attr["numeric"])) - attr.pop("Extra") + attr.pop("extra") + attr.pop("key") # process custom DataJoint types stored in comment - special = re.match(r":(?P[^:]+):(?P.*)", attr["comment"]) + comment = attr["comment"] or "" # Handle None for PostgreSQL + special = re.match(r":(?P[^:]+):(?P.*)", comment) if special: special = special.groupdict() attr["comment"] = special["comment"] # Always update the comment @@ -519,21 +552,32 @@ def _init_from_database(self) -> None: # Read and tabulate secondary indexes keys = defaultdict(dict) for item in conn.query( - "SHOW KEYS FROM `{db}`.`{tab}`".format(db=database, tab=table_name), + adapter.get_indexes_sql(database, table_name), as_dict=True, ): - if item["Key_name"] != "PRIMARY": - keys[item["Key_name"]][item["Seq_in_index"]] = dict( - column=item["Column_name"] or f"({item['Expression']})".replace(r"\'", "'"), - unique=(item["Non_unique"] == 0), - nullable=item["Null"].lower() == "yes", - ) + # Note: adapter.get_indexes_sql() already filters out PRIMARY key + # MySQL/PostgreSQL adapters return: index_name, column_name, non_unique + index_name = item.get("index_name") or item.get("Key_name") + seq = item.get("seq_in_index") or item.get("Seq_in_index") or len(keys[index_name]) + 1 + column = item.get("column_name") or item.get("Column_name") + # MySQL EXPRESSION column stores escaped single quotes - unescape them + if column: + column = column.replace("\\'", "'") + non_unique = item.get("non_unique") or item.get("Non_unique") + nullable = item.get("nullable") or (item.get("Null", "NO").lower() == "yes") + + keys[index_name][seq] = dict( + column=column, + unique=(non_unique == 0 or not non_unique), + nullable=nullable, + ) self.indexes = { - tuple(item[k]["column"] for k in sorted(item.keys())): dict( + tuple(item[k]["column"] for k in sorted(item.keys()) if item[k]["column"] is not None): dict( unique=item[1]["unique"], nullable=any(v["nullable"] for v in item.values()), ) for item in keys.values() + if any(item[k]["column"] is not None for k in item.keys()) } def select(self, select_list, rename_map=None, compute_map=None): @@ -548,6 +592,8 @@ def select(self, select_list, rename_map=None, compute_map=None): """ rename_map = rename_map or {} compute_map = compute_map or {} + # Get adapter for proper identifier quoting + adapter = self.table_info["conn"].adapter if self.table_info else None copy_attrs = list() for name in 
self.attributes: if name in select_list: @@ -557,7 +603,7 @@ def select(self, select_list, rename_map=None, compute_map=None): dict( self.attributes[old_name].todict(), name=new_name, - attribute_expression="`%s`" % old_name, + attribute_expression=(adapter.quote_identifier(old_name) if adapter else f"`{old_name}`"), ) for new_name, old_name in rename_map.items() if old_name == name @@ -567,7 +613,10 @@ def select(self, select_list, rename_map=None, compute_map=None): dict(default_attribute_properties, name=new_name, attribute_expression=expr) for new_name, expr in compute_map.items() ) - return Heading(chain(copy_attrs, compute_attrs), lineage_available=self._lineage_available) + # Inherit table_info so the new heading has access to the adapter + new_heading = Heading(chain(copy_attrs, compute_attrs), lineage_available=self._lineage_available) + new_heading.table_info = self.table_info + return new_heading def _join_dependent(self, dependent): """Build attribute list when self → dependent: PK = PK(self), self's attrs first.""" diff --git a/src/datajoint/jobs.py b/src/datajoint/jobs.py index d54109db5..23699da30 100644 --- a/src/datajoint/jobs.py +++ b/src/datajoint/jobs.py @@ -249,7 +249,7 @@ def pending(self) -> "Job": Job Restricted query with ``status='pending'``. """ - return self & 'status="pending"' + return self & "status='pending'" @property def reserved(self) -> "Job": @@ -261,7 +261,7 @@ def reserved(self) -> "Job": Job Restricted query with ``status='reserved'``. """ - return self & 'status="reserved"' + return self & "status='reserved'" @property def errors(self) -> "Job": @@ -273,7 +273,7 @@ def errors(self) -> "Job": Job Restricted query with ``status='error'``. """ - return self & 'status="error"' + return self & "status='error'" @property def ignored(self) -> "Job": @@ -285,7 +285,7 @@ def ignored(self) -> "Job": Job Restricted query with ``status='ignore'``. """ - return self & 'status="ignore"' + return self & "status='ignore'" @property def completed(self) -> "Job": @@ -297,7 +297,7 @@ def completed(self) -> "Job": Job Restricted query with ``status='success'``. """ - return self & 'status="success"' + return self & "status='success'" # ------------------------------------------------------------------------- # Core job management methods @@ -376,7 +376,8 @@ def refresh( if new_key_list: # Use server time for scheduling (CURRENT_TIMESTAMP(3) matches datetime(3) precision) - scheduled_time = self.connection.query(f"SELECT CURRENT_TIMESTAMP(3) + INTERVAL {delay} SECOND").fetchone()[0] + interval_expr = self.adapter.interval_expr(delay, "second") + scheduled_time = self.connection.query(f"SELECT CURRENT_TIMESTAMP(3) + {interval_expr}").fetchone()[0] for key in new_key_list: job_entry = { @@ -404,7 +405,8 @@ def refresh( # 3. Remove stale jobs (not ignore status) - use server CURRENT_TIMESTAMP for consistent timing if stale_timeout > 0: - old_jobs = self & f"created_time < CURRENT_TIMESTAMP - INTERVAL {stale_timeout} SECOND" & 'status != "ignore"' + stale_interval = self.adapter.interval_expr(stale_timeout, "second") + old_jobs = self & f"created_time < CURRENT_TIMESTAMP - {stale_interval}" & "status != 'ignore'" for key in old_jobs.keys(): # Check if key still in key_source @@ -414,7 +416,8 @@ def refresh( # 4. 
Handle orphaned reserved jobs - use server CURRENT_TIMESTAMP for consistent timing if orphan_timeout is not None and orphan_timeout > 0: - orphaned_jobs = self.reserved & f"reserved_time < CURRENT_TIMESTAMP - INTERVAL {orphan_timeout} SECOND" + orphan_interval = self.adapter.interval_expr(orphan_timeout, "second") + orphaned_jobs = self.reserved & f"reserved_time < CURRENT_TIMESTAMP - {orphan_interval}" for key in orphaned_jobs.keys(): (self & key).delete_quick() @@ -441,7 +444,7 @@ def reserve(self, key: dict) -> bool: True if reservation successful, False if job not available. """ # Check if job is pending and scheduled (use CURRENT_TIMESTAMP(3) for datetime(3) precision) - job = (self & key & 'status="pending"' & "scheduled_time <= CURRENT_TIMESTAMP(3)").to_dicts() + job = (self & key & "status='pending'" & "scheduled_time <= CURRENT_TIMESTAMP(3)").to_dicts() if not job: return False diff --git a/src/datajoint/lineage.py b/src/datajoint/lineage.py index d40ed8dd8..bb911a876 100644 --- a/src/datajoint/lineage.py +++ b/src/datajoint/lineage.py @@ -38,17 +38,30 @@ def ensure_lineage_table(connection, database): database : str The schema/database name. """ - connection.query( - """ - CREATE TABLE IF NOT EXISTS `{database}`.`~lineage` ( - table_name VARCHAR(64) NOT NULL COMMENT 'table name within the schema', - attribute_name VARCHAR(64) NOT NULL COMMENT 'attribute name', - lineage VARCHAR(255) NOT NULL COMMENT 'origin: schema.table.attribute', - PRIMARY KEY (table_name, attribute_name) - ) ENGINE=InnoDB - """.format(database=database) + adapter = connection.adapter + + # Build fully qualified table name + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + + # Build column definitions using adapter + columns = [ + adapter.format_column_definition("table_name", "VARCHAR(64)", nullable=False, comment="table name within the schema"), + adapter.format_column_definition("attribute_name", "VARCHAR(64)", nullable=False, comment="attribute name"), + adapter.format_column_definition("lineage", "VARCHAR(255)", nullable=False, comment="origin: schema.table.attribute"), + ] + + # Build PRIMARY KEY using adapter + pk_cols = adapter.quote_identifier("table_name") + ", " + adapter.quote_identifier("attribute_name") + pk_clause = f"PRIMARY KEY ({pk_cols})" + + sql = ( + f"CREATE TABLE IF NOT EXISTS {lineage_table} (\n" + + ",\n".join(columns + [pk_clause]) + + f"\n) {adapter.table_options_clause()}" ) + connection.query(sql) + def lineage_table_exists(connection, database): """ @@ -99,11 +112,14 @@ def get_lineage(connection, database, table_name, attribute_name): if not lineage_table_exists(connection, database): return None + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + result = connection.query( - """ - SELECT lineage FROM `{database}`.`~lineage` + f""" + SELECT lineage FROM {lineage_table} WHERE table_name = %s AND attribute_name = %s - """.format(database=database), + """, args=(table_name, attribute_name), ).fetchone() return result[0] if result else None @@ -130,11 +146,14 @@ def get_table_lineages(connection, database, table_name): if not lineage_table_exists(connection, database): return {} + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + results = connection.query( - """ - SELECT attribute_name, lineage FROM `{database}`.`~lineage` + f""" + SELECT attribute_name, lineage FROM {lineage_table} 
WHERE table_name = %s - """.format(database=database), + """, args=(table_name,), ).fetchall() return {row[0]: row[1] for row in results} @@ -159,10 +178,13 @@ def get_schema_lineages(connection, database): if not lineage_table_exists(connection, database): return {} + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + results = connection.query( - """ - SELECT table_name, attribute_name, lineage FROM `{database}`.`~lineage` - """.format(database=database), + f""" + SELECT table_name, attribute_name, lineage FROM {lineage_table} + """, ).fetchall() return {f"{database}.{table}.{attr}": lineage for table, attr, lineage in results} @@ -184,18 +206,25 @@ def insert_lineages(connection, database, entries): if not entries: return ensure_lineage_table(connection, database) - # Build a single INSERT statement with multiple values for atomicity - placeholders = ", ".join(["(%s, %s, %s)"] * len(entries)) + + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + + # Build backend-agnostic upsert statement + columns = ["table_name", "attribute_name", "lineage"] + primary_key = ["table_name", "attribute_name"] + + sql = adapter.upsert_on_duplicate_sql( + lineage_table, + columns, + primary_key, + len(entries), + ) + # Flatten the entries into a single args tuple args = tuple(val for entry in entries for val in entry) - connection.query( - """ - INSERT INTO `{database}`.`~lineage` (table_name, attribute_name, lineage) - VALUES {placeholders} - ON DUPLICATE KEY UPDATE lineage = VALUES(lineage) - """.format(database=database, placeholders=placeholders), - args=args, - ) + + connection.query(sql, args=args) def delete_table_lineages(connection, database, table_name): @@ -213,11 +242,15 @@ def delete_table_lineages(connection, database, table_name): """ if not lineage_table_exists(connection, database): return + + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + connection.query( - """ - DELETE FROM `{database}`.`~lineage` + f""" + DELETE FROM {lineage_table} WHERE table_name = %s - """.format(database=database), + """, args=(table_name,), ) @@ -251,8 +284,11 @@ def rebuild_schema_lineage(connection, database): # Ensure the lineage table exists ensure_lineage_table(connection, database) + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + # Clear all existing lineage entries for this schema - connection.query(f"DELETE FROM `{database}`.`~lineage`") + connection.query(f"DELETE FROM {lineage_table}") # Get all tables in the schema (excluding hidden tables) tables_result = connection.query( diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 98faa83f2..1a9958a21 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -190,7 +190,8 @@ def activate( # create database logger.debug("Creating schema `{name}`.".format(name=schema_name)) try: - self.connection.query("CREATE DATABASE `{name}`".format(name=schema_name)) + create_sql = self.connection.adapter.create_schema_sql(schema_name) + self.connection.query(create_sql) except AccessError: raise DataJointError( "Schema `{name}` does not exist and could not be created. 
Check permissions.".format(name=schema_name) @@ -413,7 +414,8 @@ def drop(self, prompt: bool | None = None) -> None: elif not prompt or user_choice("Proceed to delete entire schema `%s`?" % self.database, default="no") == "yes": logger.debug("Dropping `{database}`.".format(database=self.database)) try: - self.connection.query("DROP DATABASE `{database}`".format(database=self.database)) + drop_sql = self.connection.adapter.drop_schema_sql(self.database) + self.connection.query(drop_sql) logger.debug("Schema `{database}` was dropped successfully.".format(database=self.database)) except AccessError: raise AccessError( @@ -515,13 +517,17 @@ def jobs(self) -> list[Job]: jobs_list = [] # Get all existing job tables (~~prefix) - # Note: %% escapes the % in pymysql - result = self.connection.query(f"SHOW TABLES IN `{self.database}` LIKE '~~%%'").fetchall() + # Note: %% escapes the % in pymysql/psycopg2 + adapter = self.connection.adapter + sql = adapter.list_tables_sql(self.database, pattern="~~%%") + result = self.connection.query(sql).fetchall() existing_job_tables = {row[0] for row in result} # Iterate over auto-populated tables and check if their job table exists for table_name in self.list_tables(): - table = FreeTable(self.connection, f"`{self.database}`.`{table_name}`") + adapter = self.connection.adapter + full_name = f"{adapter.quote_identifier(self.database)}." f"{adapter.quote_identifier(table_name)}" + table = FreeTable(self.connection, full_name) tier = _get_tier(table.full_table_name) if tier in (Computed, Imported): # Compute expected job table name: ~~base_name @@ -694,7 +700,8 @@ def get_table(self, name: str) -> FreeTable: if table_name is None: raise DataJointError(f"Table `{name}` does not exist in schema `{self.database}`.") - full_name = f"`{self.database}`.`{table_name}`" + adapter = self.connection.adapter + full_name = f"{adapter.quote_identifier(self.database)}.{adapter.quote_identifier(table_name)}" return FreeTable(self.connection, full_name) def __getitem__(self, name: str) -> FreeTable: @@ -892,7 +899,7 @@ def virtual_schema( -------- >>> lab = dj.virtual_schema('my_lab') >>> lab.Subject.fetch() - >>> lab.Session & 'subject_id="M001"' + >>> lab.Session & "subject_id='M001'" See Also -------- diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index e9b6f6570..ddd1b487a 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -15,6 +15,10 @@ >>> import datajoint as dj >>> dj.config.database.host 'localhost' +>>> dj.config.database.backend +'mysql' +>>> dj.config.database.port # Auto-detects: 3306 for MySQL, 5432 for PostgreSQL +3306 >>> with dj.config.override(safemode=False): ... # dangerous operations here ... 
pass @@ -43,7 +47,7 @@ from pathlib import Path from typing import Any, Iterator, Literal -from pydantic import Field, SecretStr, field_validator +from pydantic import Field, SecretStr, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from .errors import DataJointError @@ -59,8 +63,10 @@ "database.host": "DJ_HOST", "database.user": "DJ_USER", "database.password": "DJ_PASS", + "database.backend": "DJ_BACKEND", "database.port": "DJ_PORT", "loglevel": "DJ_LOG_LEVEL", + "display.diagram_direction": "DJ_DIAGRAM_DIRECTION", } Role = Enum("Role", "manual lookup imported computed job") @@ -182,9 +188,21 @@ class DatabaseSettings(BaseSettings): host: str = Field(default="localhost", validation_alias="DJ_HOST") user: str | None = Field(default=None, validation_alias="DJ_USER") password: SecretStr | None = Field(default=None, validation_alias="DJ_PASS") - port: int = Field(default=3306, validation_alias="DJ_PORT") + backend: Literal["mysql", "postgresql"] = Field( + default="mysql", + validation_alias="DJ_BACKEND", + description="Database backend: 'mysql' or 'postgresql'", + ) + port: int | None = Field(default=None, validation_alias="DJ_PORT") reconnect: bool = True - use_tls: bool | None = None + use_tls: bool | None = Field(default=None, validation_alias="DJ_USE_TLS") + + @model_validator(mode="after") + def set_default_port_from_backend(self) -> "DatabaseSettings": + """Set default port based on backend if not explicitly provided.""" + if self.port is None: + self.port = 5432 if self.backend == "postgresql" else 3306 + return self class ConnectionSettings(BaseSettings): @@ -204,6 +222,11 @@ class DisplaySettings(BaseSettings): limit: int = 12 width: int = 14 show_tuple_count: bool = True + diagram_direction: Literal["TB", "LR"] = Field( + default="LR", + validation_alias="DJ_DIAGRAM_DIRECTION", + description="Default diagram layout direction: 'TB' (top-to-bottom) or 'LR' (left-to-right)", + ) class StoresSettings(BaseSettings): diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 8c672f41e..02f1b2bb6 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -4,7 +4,6 @@ import itertools import json import logging -import re import uuid import warnings from dataclasses import dataclass, field @@ -30,24 +29,8 @@ logger = logging.getLogger(__name__.split(".")[0]) -foreign_key_error_regexp = re.compile( - r"[\w\s:]*\((?P`[^`]+`.`[^`]+`), " - r"CONSTRAINT (?P`[^`]+`) " - r"(FOREIGN KEY \((?P[^)]+)\) " - r"REFERENCES (?P`[^`]+`(\.`[^`]+`)?) \((?P[^)]+)\)[\s\w]+\))?" -) - -constraint_info_query = " ".join( - """ - SELECT - COLUMN_NAME as fk_attrs, - CONCAT('`', REFERENCED_TABLE_SCHEMA, '`.`', REFERENCED_TABLE_NAME, '`') as parent, - REFERENCED_COLUMN_NAME as pk_attrs - FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE - WHERE - CONSTRAINT_NAME = %s AND TABLE_SCHEMA = %s AND TABLE_NAME = %s; - """.split() -) +# Note: Foreign key error parsing is now handled by adapter methods +# Legacy regexp and query kept for reference but no longer used class _RenameMap(tuple): @@ -163,14 +146,26 @@ def declare(self, context=None): "Table class name `{name}` is invalid. Please use CamelCase. ".format(name=self.class_name) + "Classes defining tables should be formatted in strict CamelCase." 
) - sql, _external_stores, primary_key, fk_attribute_map = declare(self.full_table_name, self.definition, context) + sql, _external_stores, primary_key, fk_attribute_map, pre_ddl, post_ddl = declare( + self.full_table_name, self.definition, context, self.connection.adapter + ) # Call declaration hook for validation (subclasses like AutoPopulate can override) self._declare_check(primary_key, fk_attribute_map) sql = sql.format(database=self.database) try: + # Execute pre-DDL statements (e.g., CREATE TYPE for PostgreSQL enums) + for ddl in pre_ddl: + try: + self.connection.query(ddl.format(database=self.database)) + except Exception: + # Ignore errors (type may already exist) + pass self.connection.query(sql) + # Execute post-DDL statements (e.g., COMMENT ON for PostgreSQL) + for ddl in post_ddl: + self.connection.query(ddl.format(database=self.database)) except AccessError: # Only suppress if table already exists (idempotent declaration) # Otherwise raise - user needs to know about permission issues @@ -225,8 +220,8 @@ def _populate_lineage(self, primary_key, fk_attribute_map): # FK attributes: copy lineage from parent (whether in PK or not) for attr, (parent_table, parent_attr) in fk_attribute_map.items(): - # Parse parent table name: `schema`.`table` -> (schema, table) - parent_clean = parent_table.replace("`", "") + # Parse parent table name: `schema`.`table` or "schema"."table" -> (schema, table) + parent_clean = parent_table.replace("`", "").replace('"', "") if "." in parent_clean: parent_db, parent_tbl = parent_clean.split(".", 1) else: @@ -270,7 +265,7 @@ def alter(self, prompt=True, context=None): context = dict(frame.f_globals, **frame.f_locals) del frame old_definition = self.describe(context=context) - sql, _external_stores = alter(self.definition, old_definition, context) + sql, _external_stores = alter(self.definition, old_definition, context, self.connection.adapter) if not sql: if prompt: logger.warning("Nothing to alter.") @@ -384,12 +379,8 @@ def is_declared(self): """ :return: True is the table is declared in the schema. """ - return ( - self.connection.query( - 'SHOW TABLES in `{database}` LIKE "{table_name}"'.format(database=self.database, table_name=self.table_name) - ).rowcount - > 0 - ) + query = self.connection.adapter.get_table_info_sql(self.database, self.table_name) + return self.connection.query(query).rowcount > 0 @property def full_table_name(self): @@ -401,7 +392,12 @@ def full_table_name(self): f"Class {self.__class__.__name__} is not associated with a schema. " "Apply a schema decorator or use schema() to bind it." 
) - return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) + return f"{self.adapter.quote_identifier(self.database)}.{self.adapter.quote_identifier(self.table_name)}" + + @property + def adapter(self): + """Database adapter for backend-agnostic SQL generation.""" + return self.connection.adapter def update1(self, row): """ @@ -438,9 +434,10 @@ def update1(self, row): raise DataJointError("Update can only be applied to one existing entry.") # UPDATE query row = [self.__make_placeholder(k, v) for k, v in row.items() if k not in self.primary_key] + assignments = ",".join(f"{self.adapter.quote_identifier(r[0])}={r[1]}" for r in row) query = "UPDATE {table} SET {assignments} WHERE {where}".format( table=self.full_table_name, - assignments=",".join("`%s`=%s" % r[:2] for r in row), + assignments=assignments, where=make_condition(self, key, set()), ) self.connection.query(query, args=list(r[2] for r in row if r[2] is not None)) @@ -694,17 +691,16 @@ def insert( except StopIteration: pass fields = list(name for name in rows.heading if name in self.heading) - query = "{command} INTO {table} ({fields}) {select}{duplicate}".format( - command="REPLACE" if replace else "INSERT", - fields="`" + "`,`".join(fields) + "`", - table=self.full_table_name, - select=rows.make_sql(fields), - duplicate=( - " ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`".format(table=self.full_table_name, pk=self.primary_key[0]) - if skip_duplicates - else "" - ), - ) + quoted_fields = ",".join(self.adapter.quote_identifier(f) for f in fields) + + # Duplicate handling (backend-agnostic) + if skip_duplicates: + duplicate = self.adapter.skip_duplicates_clause(self.full_table_name, self.primary_key) + else: + duplicate = "" + + command = "REPLACE" if replace else "INSERT" + query = f"{command} INTO {self.full_table_name} ({quoted_fields}) {rows.make_sql(fields)}{duplicate}" self.connection.query(query) return @@ -736,16 +732,20 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields): if rows: try: # Handle empty field_list (all-defaults insert) - fields_clause = f"(`{'`,`'.join(field_list)}`)" if field_list else "()" - query = "{command} INTO {destination}{fields} VALUES {placeholders}{duplicate}".format( - command="REPLACE" if replace else "INSERT", - destination=self.from_clause(), - fields=fields_clause, - placeholders=",".join("(" + ",".join(row["placeholders"]) + ")" for row in rows), - duplicate=( - " ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`".format(pk=self.primary_key[0]) if skip_duplicates else "" - ), - ) + if field_list: + fields_clause = f"({','.join(self.adapter.quote_identifier(f) for f in field_list)})" + else: + fields_clause = "()" + + # Build duplicate clause (backend-agnostic) + if skip_duplicates: + duplicate = self.adapter.skip_duplicates_clause(self.full_table_name, self.primary_key) + else: + duplicate = "" + + command = "REPLACE" if replace else "INSERT" + placeholders = ",".join("(" + ",".join(row["placeholders"]) + ")" for row in rows) + query = f"{command} INTO {self.from_clause()}{fields_clause} VALUES {placeholders}{duplicate}" self.connection.query( query, args=list(itertools.chain.from_iterable((v for v in r["values"] if v is not None) for r in rows)), @@ -836,8 +836,9 @@ def delete_quick(self, get_count=False): If this table has populated dependent tables, this will fail. 
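+        Example (illustrative; ``table`` and ``key`` are hypothetical)::
+
+            n_removed = (table & key).delete_quick(get_count=True)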
""" query = "DELETE FROM " + self.full_table_name + self.where_clause() - self.connection.query(query) - count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0] if get_count else None + cursor = self.connection.query(query) + # Use cursor.rowcount (DB-API 2.0 standard, works for both MySQL and PostgreSQL) + count = cursor.rowcount if get_count else None return count def delete( @@ -878,44 +879,72 @@ def cascade(table): """service function to perform cascading deletes recursively.""" max_attempts = 50 for _ in range(max_attempts): + # Set savepoint before delete attempt (for PostgreSQL transaction handling) + savepoint_name = f"cascade_delete_{id(table)}" + if transaction: + table.connection.query(f"SAVEPOINT {savepoint_name}") + try: delete_count = table.delete_quick(get_count=True) except IntegrityError as error: - match = foreign_key_error_regexp.match(error.args[0]) + # Rollback to savepoint so we can continue querying (PostgreSQL requirement) + if transaction: + table.connection.query(f"ROLLBACK TO SAVEPOINT {savepoint_name}") + # Use adapter to parse FK error message + match = table.connection.adapter.parse_foreign_key_error(error.args[0]) if match is None: raise DataJointError( - "Cascading deletes failed because the error message is missing foreign key information." + "Cascading deletes failed because the error message is missing foreign key information. " "Make sure you have REFERENCES privilege to all dependent tables." ) from None - match = match.groupdict() - # if schema name missing, use table - if "`.`" not in match["child"]: - match["child"] = "{}.{}".format(table.full_table_name.split(".")[0], match["child"]) - if match["pk_attrs"] is not None: # fully matched, adjusting the keys - match["fk_attrs"] = [k.strip("`") for k in match["fk_attrs"].split(",")] - match["pk_attrs"] = [k.strip("`") for k in match["pk_attrs"].split(",")] - else: # only partially matched, querying with constraint to determine keys - match["fk_attrs"], match["parent"], match["pk_attrs"] = list( - map( - list, - zip( - *table.connection.query( - constraint_info_query, - args=( - match["name"].strip("`"), - *[_.strip("`") for _ in match["child"].split("`.`")], - ), - ).fetchall() - ), - ) + + # Strip quotes from parsed values for backend-agnostic processing + quote_chars = ("`", '"') + + def strip_quotes(s): + if s and any(s.startswith(q) for q in quote_chars): + return s.strip('`"') + return s + + # Extract schema and table name from child (work with unquoted names) + child_table_raw = strip_quotes(match["child"]) + if "." 
in child_table_raw: + child_parts = child_table_raw.split(".") + child_schema = strip_quotes(child_parts[0]) + child_table_name = strip_quotes(child_parts[1]) + else: + # Add schema from current table + schema_parts = table.full_table_name.split(".") + child_schema = strip_quotes(schema_parts[0]) + child_table_name = child_table_raw + + # If FK/PK attributes not in error message, query information_schema + if match["fk_attrs"] is None or match["pk_attrs"] is None: + constraint_query = table.connection.adapter.get_constraint_info_sql( + strip_quotes(match["name"]), + child_schema, + child_table_name, ) - match["parent"] = match["parent"][0] + + results = table.connection.query( + constraint_query, + args=(strip_quotes(match["name"]), child_schema, child_table_name), + ).fetchall() + if results: + match["fk_attrs"], match["parent"], match["pk_attrs"] = list(map(list, zip(*results))) + match["parent"] = match["parent"][0] # All rows have same parent + + # Build properly quoted full table name for FreeTable + child_full_name = ( + f"{table.connection.adapter.quote_identifier(child_schema)}." + f"{table.connection.adapter.quote_identifier(child_table_name)}" + ) # Restrict child by table if # 1. if table's restriction attributes are not in child's primary key # 2. if child renames any attributes # Otherwise restrict child by table's restriction. - child = FreeTable(table.connection, match["child"]) + child = FreeTable(table.connection, child_full_name) if set(table.restriction_attributes) <= set(child.primary_key) and match["fk_attrs"] == match["pk_attrs"]: child._restriction = table._restriction child._restriction_attributes = table.restriction_attributes @@ -924,7 +953,7 @@ def cascade(table): else: child &= table.proj() - master_name = get_master(child.full_table_name) + master_name = get_master(child.full_table_name, table.connection.adapter) if ( part_integrity == "cascade" and master_name @@ -945,6 +974,9 @@ def cascade(table): else: cascade(child) else: + # Successful delete - release savepoint + if transaction: + table.connection.query(f"RELEASE SAVEPOINT {savepoint_name}") deleted.add(table.full_table_name) logger.info("Deleting {count} rows from {table}".format(count=delete_count, table=table.full_table_name)) break @@ -977,7 +1009,7 @@ def cascade(table): if part_integrity == "enforce": # Avoid deleting from part before master (See issue #151) for part in deleted: - master = get_master(part) + master = get_master(part, self.connection.adapter) if master and master not in deleted: if transaction: self.connection.cancel_transaction() @@ -1020,9 +1052,31 @@ def drop_quick(self): delete_table_lineages(self.connection, self.database, self.table_name) + # For PostgreSQL, get enum types used by this table before dropping + # (we need to query this before the table is dropped) + enum_types_to_drop = [] + adapter = self.connection.adapter + if hasattr(adapter, "get_table_enum_types_sql"): + try: + enum_query = adapter.get_table_enum_types_sql(self.database, self.table_name) + result = self.connection.query(enum_query) + enum_types_to_drop = [row[0] for row in result.fetchall()] + except Exception: + pass # Ignore errors - enum cleanup is best-effort + query = "DROP TABLE %s" % self.full_table_name self.connection.query(query) logger.info("Dropped table %s" % self.full_table_name) + + # For PostgreSQL, clean up enum types after dropping the table + if enum_types_to_drop and hasattr(adapter, "drop_enum_type_ddl"): + for enum_type in enum_types_to_drop: + try: + drop_ddl = 
adapter.drop_enum_type_ddl(enum_type) + self.connection.query(drop_ddl) + logger.debug("Dropped enum type %s" % enum_type) + except Exception: + pass # Ignore errors - type may be used by other tables else: logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) @@ -1046,7 +1100,7 @@ def drop(self, prompt: bool | None = None): # avoid dropping part tables without their masters: See issue #374 for part in tables: - master = get_master(part) + master = get_master(part, self.connection.adapter) if master and master not in tables: raise DataJointError( "Attempt to drop part table {part} before dropping its master. Drop {master} first.".format( @@ -1089,7 +1143,7 @@ def describe(self, context=None, printout=False): definition = "# " + self.heading.table_status["comment"] + "\n" if self.heading.table_status["comment"] else "" attributes_thus_far = set() attributes_declared = set() - indexes = self.heading.indexes.copy() + indexes = self.heading.indexes.copy() if self.heading.indexes else {} for attr in self.heading.attributes.values(): if in_key and not attr.in_key: definition += "---\n" @@ -1375,7 +1429,8 @@ class FreeTable(Table): """ def __init__(self, conn, full_table_name): - self.database, self._table_name = (s.strip("`") for s in full_table_name.split(".")) + # Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ") + self.database, self._table_name = (s.strip('`"') for s in full_table_name.split(".")) self._connection = conn self._support = [full_table_name] self._heading = Heading( @@ -1388,4 +1443,4 @@ def __init__(self, conn, full_table_name): ) def __repr__(self): - return "FreeTable(`%s`.`%s`)\n" % (self.database, self._table_name) + super().__repr__() + return f"FreeTable({self.full_table_name})\n" + super().__repr__() diff --git a/src/datajoint/user_tables.py b/src/datajoint/user_tables.py index 942179685..4b6d0d571 100644 --- a/src/datajoint/user_tables.py +++ b/src/datajoint/user_tables.py @@ -102,10 +102,11 @@ def table_name(cls): @property def full_table_name(cls): - """The fully qualified table name (`database`.`table`).""" + """The fully qualified table name (quoted per backend).""" if cls.database is None: return None - return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + adapter = cls._connection.adapter + return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}" class UserTable(Table, metaclass=TableMeta): @@ -181,10 +182,11 @@ def table_name(cls): @property def full_table_name(cls): - """The fully qualified table name (`database`.`table`).""" + """The fully qualified table name (quoted per backend).""" if cls.database is None or cls.table_name is None: return None - return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + adapter = cls._connection.adapter + return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}" @property def master(cls): @@ -274,10 +276,16 @@ class _AliasNode: def _get_tier(table_name): """given the table name, return the user table class.""" - if not table_name.startswith("`"): - return _AliasNode + # Handle both MySQL backticks and PostgreSQL double quotes + if table_name.startswith("`"): + # MySQL format: `schema`.`table_name` + extracted_name = table_name.split("`")[-2] + elif table_name.startswith('"'): + # PostgreSQL format: "schema"."table_name" + extracted_name = table_name.split('"')[-2] else: - try: - return next(tier for tier in user_table_classes if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])) - 
except StopIteration: - return None + return _AliasNode + try: + return next(tier for tier in user_table_classes if re.fullmatch(tier.tier_regexp, extracted_name)) + except StopIteration: + return None diff --git a/src/datajoint/utils.py b/src/datajoint/utils.py index e8303a993..4309d78b9 100644 --- a/src/datajoint/utils.py +++ b/src/datajoint/utils.py @@ -25,23 +25,32 @@ def user_choice(prompt, choices=("yes", "no"), default=None): return response -def get_master(full_table_name: str) -> str: +def get_master(full_table_name: str, adapter=None) -> str: """ If the table name is that of a part table, then return what the master table name would be. This follows DataJoint's table naming convention where a master and a part must be in the same schema and the part table is prefixed with the master table name + ``__``. Example: - `ephys`.`session` -- master - `ephys`.`session__recording` -- part + `ephys`.`session` -- master (MySQL) + `ephys`.`session__recording` -- part (MySQL) + "ephys"."session__recording" -- part (PostgreSQL) :param full_table_name: Full table name including part. :type full_table_name: str + :param adapter: Optional database adapter for backend-specific parsing. :return: Supposed master full table name or empty string if not a part table name. :rtype: str """ - match = re.match(r"(?P`\w+`.`\w+)__(?P\w+)`", full_table_name) - return match["master"] + "`" if match else "" + if adapter is not None: + result = adapter.get_master_table_name(full_table_name) + return result if result else "" + + # Fallback: handle both MySQL backticks and PostgreSQL double quotes + match = re.match(r'(?P(?P[`"])[\w]+(?P=q)\.(?P=q)[\w]+)__[\w]+(?P=q)', full_table_name) + if match: + return match["master"] + match["q"] + return "" def is_camel_case(s): diff --git a/src/datajoint/version.py b/src/datajoint/version.py index 98a5f2b93..f19a270de 100644 --- a/src/datajoint/version.py +++ b/src/datajoint/version.py @@ -1,4 +1,4 @@ # version bump auto managed by Github Actions: # label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit) # manually set this version will be eventually overwritten by the above actions -__version__ = "2.0.0a22" +__version__ = "2.1.0a7" diff --git a/tests/conftest.py b/tests/conftest.py index dc2eb73b6..4d6adf09c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,6 +66,12 @@ def pytest_collection_modifyitems(config, items): "stores_config", "mock_stores", } + # Tests that use these fixtures are backend-parameterized + backend_fixtures = { + "backend", + "db_creds_by_backend", + "connection_by_backend", + } for item in items: # Get all fixtures this test uses (directly or indirectly) @@ -80,6 +86,13 @@ def pytest_collection_modifyitems(config, items): if fixturenames & minio_fixtures: item.add_marker(pytest.mark.requires_minio) + # Auto-mark backend-parameterized tests + if fixturenames & backend_fixtures: + # Test will run for both backends - add all backend markers + item.add_marker(pytest.mark.mysql) + item.add_marker(pytest.mark.postgresql) + item.add_marker(pytest.mark.backend_agnostic) + # ============================================================================= # Container Fixtures - Auto-start MySQL and MinIO via testcontainers @@ -101,7 +114,7 @@ def mysql_container(): from testcontainers.mysql import MySqlContainer container = MySqlContainer( - image="mysql:8.0", + image="datajoint/mysql:8.0", # Use datajoint image which has SSL configured username="root", password="password", dbname="test", @@ -118,6 +131,35 @@ def mysql_container(): 
logger.info("MySQL container stopped") +@pytest.fixture(scope="session") +def postgres_container(): + """Start PostgreSQL container for the test session (or use external).""" + if USE_EXTERNAL_CONTAINERS: + # Use external container - return None, credentials come from env + logger.info("Using external PostgreSQL container") + yield None + return + + from testcontainers.postgres import PostgresContainer + + container = PostgresContainer( + image="postgres:15", + username="postgres", + password="password", + dbname="test", + ) + container.start() + + host = container.get_container_host_ip() + port = container.get_exposed_port(5432) + logger.info(f"PostgreSQL container started at {host}:{port}") + + yield container + + container.stop() + logger.info("PostgreSQL container stopped") + + @pytest.fixture(scope="session") def minio_container(): """Start MinIO container for the test session (or use external).""" @@ -225,6 +267,107 @@ def s3_creds(minio_container) -> Dict: ) +# ============================================================================= +# Backend-Parameterized Fixtures +# ============================================================================= + + +@pytest.fixture(scope="session", params=["mysql", "postgresql"]) +def backend(request): + """Parameterize tests to run against both backends.""" + return request.param + + +@pytest.fixture(scope="session") +def db_creds_by_backend(backend, mysql_container, postgres_container): + """Get root database credentials for the specified backend.""" + if backend == "mysql": + if mysql_container is not None: + host = mysql_container.get_container_host_ip() + port = mysql_container.get_exposed_port(3306) + return { + "backend": "mysql", + "host": f"{host}:{port}", + "user": "root", + "password": "password", + } + else: + # External MySQL container + host = os.environ.get("DJ_HOST", "localhost") + port = os.environ.get("DJ_PORT", "3306") + return { + "backend": "mysql", + "host": f"{host}:{port}" if port else host, + "user": os.environ.get("DJ_USER", "root"), + "password": os.environ.get("DJ_PASS", "password"), + } + + elif backend == "postgresql": + if postgres_container is not None: + host = postgres_container.get_container_host_ip() + port = postgres_container.get_exposed_port(5432) + return { + "backend": "postgresql", + "host": f"{host}:{port}", + "user": "postgres", + "password": "password", + } + else: + # External PostgreSQL container + host = os.environ.get("DJ_PG_HOST", "localhost") + port = os.environ.get("DJ_PG_PORT", "5432") + return { + "backend": "postgresql", + "host": f"{host}:{port}" if port else host, + "user": os.environ.get("DJ_PG_USER", "postgres"), + "password": os.environ.get("DJ_PG_PASS", "password"), + } + + +@pytest.fixture(scope="function") +def connection_by_backend(db_creds_by_backend): + """Create connection for the specified backend. + + This fixture is function-scoped to ensure database.backend config + is restored after each test, preventing config pollution between tests. 
+ """ + # Save original config to restore after tests + original_backend = dj.config.get("database.backend", "mysql") + original_host = dj.config.get("database.host") + original_port = dj.config.get("database.port") + + # Configure backend + dj.config["database.backend"] = db_creds_by_backend["backend"] + + # Parse host:port + host_port = db_creds_by_backend["host"] + if ":" in host_port: + host, port = host_port.rsplit(":", 1) + else: + host = host_port + port = "3306" if db_creds_by_backend["backend"] == "mysql" else "5432" + + dj.config["database.host"] = host + dj.config["database.port"] = int(port) + dj.config["safemode"] = False + + connection = dj.Connection( + host=host_port, + user=db_creds_by_backend["user"], + password=db_creds_by_backend["password"], + ) + + yield connection + + # Restore original config + connection.close() + dj.config["database.backend"] = original_backend + if original_host is not None: + dj.config["database.host"] = original_host + if original_port is not None: + dj.config["database.port"] = original_port + + # ============================================================================= # DataJoint Configuration # ============================================================================= diff --git a/tests/integration/test_cascade_delete.py b/tests/integration/test_cascade_delete.py new file mode 100644 index 000000000..caf5f331b --- /dev/null +++ b/tests/integration/test_cascade_delete.py @@ -0,0 +1,190 @@ +""" +Integration tests for cascade delete on multiple backends. +""" + +import pytest + +import datajoint as dj + + +@pytest.fixture(scope="function") +def schema_by_backend(connection_by_backend, db_creds_by_backend, request): + """Create a schema for cascade delete tests.""" + backend = db_creds_by_backend["backend"] + # Use unique schema name for each test + import time + + test_id = str(int(time.time() * 1000))[-8:] # Last 8 digits of timestamp + schema_name = f"djtest_cascade_{backend}_{test_id}"[:64] # Limit length + + # Drop schema if exists (cleanup from any previous failed runs) + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass # Ignore errors during cleanup + + # Create fresh schema + schema = dj.Schema(schema_name, connection=connection_by_backend) + + yield schema + + # Cleanup after test + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass # Ignore errors during cleanup + + +def test_simple_cascade_delete(schema_by_backend): + """Test basic cascade delete with foreign keys.""" + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + parent_id : int + --- + name : varchar(255) + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int + --- + data : varchar(255) + """ + + # Insert test data + Parent.insert1((1, "Parent1")) + Parent.insert1((2, "Parent2")) + Child.insert1((1, 1, "Child1-1")) + Child.insert1((1, 2, "Child1-2")) + Child.insert1((2, 1, "Child2-1")) + + assert len(Parent()) == 2 + assert len(Child()) == 3 + + # Delete parent with cascade + (Parent & {"parent_id": 1}).delete() + + # Check cascade worked + assert len(Parent()) == 1 + assert len(Child()) == 1 + + # Verify remaining data (using to_dicts for DJ 2.0) + remaining = Child().to_dicts() + assert len(remaining) == 1 + 
assert remaining[0]["parent_id"] == 2 + assert remaining[0]["child_id"] == 1 + assert remaining[0]["data"] == "Child2-1" + + +def test_multi_level_cascade_delete(schema_by_backend): + """Test cascade delete through multiple levels of foreign keys.""" + + @schema_by_backend + class GrandParent(dj.Manual): + definition = """ + gp_id : int + --- + name : varchar(255) + """ + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + -> GrandParent + parent_id : int + --- + name : varchar(255) + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int + --- + data : varchar(255) + """ + + # Insert test data + GrandParent.insert1((1, "GP1")) + Parent.insert1((1, 1, "P1")) + Parent.insert1((1, 2, "P2")) + Child.insert1((1, 1, 1, "C1")) + Child.insert1((1, 1, 2, "C2")) + Child.insert1((1, 2, 1, "C3")) + + assert len(GrandParent()) == 1 + assert len(Parent()) == 2 + assert len(Child()) == 3 + + # Delete grandparent - should cascade through parent to child + (GrandParent & {"gp_id": 1}).delete() + + # Check everything is deleted + assert len(GrandParent()) == 0 + assert len(Parent()) == 0 + assert len(Child()) == 0 + + # Verify all tables are empty + assert len(GrandParent().to_dicts()) == 0 + assert len(Parent().to_dicts()) == 0 + assert len(Child().to_dicts()) == 0 + + +def test_cascade_delete_with_renamed_attrs(schema_by_backend): + """Test cascade delete when foreign key renames attributes.""" + + @schema_by_backend + class Animal(dj.Manual): + definition = """ + animal_id : int + --- + species : varchar(255) + """ + + @schema_by_backend + class Observation(dj.Manual): + definition = """ + obs_id : int + --- + -> Animal.proj(subject_id='animal_id') + measurement : float + """ + + # Insert test data + Animal.insert1((1, "Mouse")) + Animal.insert1((2, "Rat")) + Observation.insert1((1, 1, 10.5)) + Observation.insert1((2, 1, 11.2)) + Observation.insert1((3, 2, 15.3)) + + assert len(Animal()) == 2 + assert len(Observation()) == 3 + + # Delete animal - should cascade to observations + (Animal & {"animal_id": 1}).delete() + + # Check cascade worked + assert len(Animal()) == 1 + assert len(Observation()) == 1 + + # Verify remaining data + remaining_animals = Animal().to_dicts() + assert len(remaining_animals) == 1 + assert remaining_animals[0]["animal_id"] == 2 + + remaining_obs = Observation().to_dicts() + assert len(remaining_obs) == 1 + assert remaining_obs[0]["obs_id"] == 3 + assert remaining_obs[0]["subject_id"] == 2 + assert remaining_obs[0]["measurement"] == 15.3 diff --git a/tests/integration/test_declare.py b/tests/integration/test_declare.py index 3097a9457..d38583cfd 100644 --- a/tests/integration/test_declare.py +++ b/tests/integration/test_declare.py @@ -44,27 +44,30 @@ def test_describe(schema_any): """real_definition should match original definition""" rel = Experiment() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_describe_indexes(schema_any): """real_definition should match original definition""" rel = IndexRich() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = 
declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_describe_dependencies(schema_any): """real_definition should match original definition""" rel = ThingC() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_part(schema_any): @@ -365,3 +368,96 @@ class Table_With_Underscores(dj.Manual): schema_any(TableNoUnderscores) with pytest.raises(dj.DataJointError, match="must be alphanumeric in CamelCase"): schema_any(Table_With_Underscores) + + +class TestSingletonTables: + """Tests for singleton tables (empty primary keys).""" + + def test_singleton_declaration(self, schema_any): + """Singleton table creates correctly with hidden _singleton attribute.""" + + @schema_any + class Config(dj.Lookup): + definition = """ + # Global configuration + --- + setting : varchar(100) + """ + + # Access attributes first to trigger lazy loading from database + visible_attrs = Config.heading.attributes + all_attrs = Config.heading._attributes + + # Table should exist and have _singleton as hidden PK + assert "_singleton" in all_attrs + assert "_singleton" not in visible_attrs + assert Config.heading.primary_key == [] # Visible PK is empty for singleton + + def test_singleton_insert_and_fetch(self, schema_any): + """Insert and fetch work without specifying _singleton.""" + + @schema_any + class Settings(dj.Lookup): + definition = """ + --- + value : int32 + """ + + # Insert without specifying _singleton + Settings.insert1({"value": 42}) + + # Fetch should work + result = Settings.fetch1() + assert result["value"] == 42 + assert "_singleton" not in result # Hidden attribute excluded + + def test_singleton_uniqueness(self, schema_any): + """Second insert raises DuplicateError.""" + + @schema_any + class SingleValue(dj.Lookup): + definition = """ + --- + data : varchar(50) + """ + + SingleValue.insert1({"data": "first"}) + + # Second insert should fail + with pytest.raises(dj.errors.DuplicateError): + SingleValue.insert1({"data": "second"}) + + def test_singleton_with_multiple_attributes(self, schema_any): + """Singleton table with multiple secondary attributes.""" + + @schema_any + class PipelineConfig(dj.Lookup): + definition = """ + # Pipeline configuration singleton + --- + version : varchar(20) + max_workers : int32 + debug_mode : bool + """ + + PipelineConfig.insert1({"version": "1.0.0", "max_workers": 4, "debug_mode": False}) + + result = PipelineConfig.fetch1() + assert result["version"] == "1.0.0" + assert result["max_workers"] == 4 + assert result["debug_mode"] == 0 # bool stored as tinyint + + def test_singleton_describe(self, schema_any): + """Describe should show the singleton nature.""" + + @schema_any + class Metadata(dj.Lookup): + definition = """ + --- + info : varchar(255) + """ + + description = Metadata.describe() + # Description should show just the secondary attribute + assert "info" in description + # _singleton is hidden, implementation detail 
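The describe-roundtrip tests above all switch to the same adapter-aware calling convention: declare() now receives the connection's adapter, and the assertions compare only the first element of its return value. The following is a minimal sketch of that pattern, not taken from the patch itself; the import path is assumed, and the only assumption about the return value is that its first element is the generated SQL, as implied by the s1[0] == s2[0] comparisons.

import inspect

from datajoint.declare import declare  # import path assumed; the tests above call declare() directly


def ddl_roundtrip_matches(table) -> bool:
    """Regenerate DDL from both the definition and describe() and compare the SQL text."""
    adapter = table.connection.adapter          # backend-specific adapter (MySQL or PostgreSQL)
    context = inspect.currentframe().f_globals  # names used to resolve foreign-key references
    sql_from_definition = declare(table.full_table_name, table.definition, context, adapter)[0]
    sql_from_describe = declare(table.full_table_name, table.describe(), context, adapter)[0]
    return sql_from_definition == sql_from_describe

Called as ddl_roundtrip_matches(Experiment()), this mirrors the repeated s1[0] == s2[0] assertions in test_declare.py, test_foreign_keys.py, and test_json.py.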
diff --git a/tests/integration/test_foreign_keys.py b/tests/integration/test_foreign_keys.py index 014340898..588c12cbf 100644 --- a/tests/integration/test_foreign_keys.py +++ b/tests/integration/test_foreign_keys.py @@ -31,8 +31,9 @@ def test_describe(schema_adv): """real_definition should match original definition""" for rel in (LocalSynapse, GlobalSynapse): describe = rel.describe() - s1 = declare(rel.full_table_name, rel.definition, schema_adv.context)[0].split("\n") - s2 = declare(rel.full_table_name, describe, globals())[0].split("\n") + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, schema_adv.context, adapter)[0].split("\n") + s2 = declare(rel.full_table_name, describe, globals(), adapter)[0].split("\n") for c1, c2 in zip(s1, s2): assert c1 == c2 diff --git a/tests/integration/test_json.py b/tests/integration/test_json.py index 40c8074de..97d0c73bf 100644 --- a/tests/integration/test_json.py +++ b/tests/integration/test_json.py @@ -122,9 +122,10 @@ def test_insert_update(schema_json): def test_describe(schema_json): rel = Team() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_restrict(schema_json): diff --git a/tests/integration/test_multi_backend.py b/tests/integration/test_multi_backend.py new file mode 100644 index 000000000..bf904e362 --- /dev/null +++ b/tests/integration/test_multi_backend.py @@ -0,0 +1,143 @@ +""" +Integration tests that verify backend-agnostic behavior. + +These tests run against both MySQL and PostgreSQL to ensure: +1. DDL generation is correct +2. SQL queries work identically +3. 
Data types map correctly + +To run these tests: + pytest tests/integration/test_multi_backend.py # Run against both backends + pytest -m "mysql" tests/integration/test_multi_backend.py # MySQL only + pytest -m "postgresql" tests/integration/test_multi_backend.py # PostgreSQL only +""" + +import pytest +import datajoint as dj + + +@pytest.mark.backend_agnostic +def test_simple_table_declaration(connection_by_backend, backend, prefix): + """Test that simple tables can be declared on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_simple", + connection=connection_by_backend, + ) + + @schema + class User(dj.Manual): + definition = """ + user_id : int + --- + username : varchar(255) + created_at : datetime + """ + + # Verify table exists + assert User.is_declared + + # Insert and fetch data + from datetime import datetime + + User.insert1((1, "alice", datetime(2025, 1, 1))) + data = User.fetch1() + + assert data["user_id"] == 1 + assert data["username"] == "alice" + + # Cleanup + schema.drop() + + +@pytest.mark.backend_agnostic +def test_foreign_keys(connection_by_backend, backend, prefix): + """Test foreign key declarations work on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_fk", + connection=connection_by_backend, + ) + + @schema + class Animal(dj.Manual): + definition = """ + animal_id : int + --- + name : varchar(255) + """ + + @schema + class Observation(dj.Manual): + definition = """ + -> Animal + obs_id : int + --- + notes : varchar(1000) + """ + + # Insert data + Animal.insert1((1, "Mouse")) + Observation.insert1((1, 1, "Active")) + + # Verify data was inserted + assert len(Animal()) == 1 + assert len(Observation()) == 1 + + # Cleanup + schema.drop() + + +@pytest.mark.backend_agnostic +def test_data_types(connection_by_backend, backend, prefix): + """Test that core data types work on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_types", + connection=connection_by_backend, + ) + + @schema + class TypeTest(dj.Manual): + definition = """ + id : int + --- + int_value : int + str_value : varchar(255) + float_value : float + bool_value : bool + """ + + # Insert data + TypeTest.insert1((1, 42, "test", 3.14, True)) + + # Fetch and verify + data = (TypeTest & {"id": 1}).fetch1() + assert data["int_value"] == 42 + assert data["str_value"] == "test" + assert abs(data["float_value"] - 3.14) < 0.001 + assert data["bool_value"] == 1 # MySQL stores as tinyint(1) + + # Cleanup + schema.drop() + + +@pytest.mark.backend_agnostic +def test_table_comments(connection_by_backend, backend, prefix): + """Test that table comments are preserved on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_comments", + connection=connection_by_backend, + ) + + @schema + class Commented(dj.Manual): + definition = """ + # This is a test table for backend testing + id : int # primary key + --- + value : varchar(255) # some value + """ + + # Verify table was created + assert Commented.is_declared + + # Cleanup + schema.drop() diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 6fcaffc6d..ef621765d 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -62,7 +62,7 @@ def test_schema_size_on_disk(schema_any): def test_schema_list(schema_any): - schemas = dj.list_schemas() + schemas = dj.list_schemas(connection=schema_any.connection) assert schema_any.database in schemas diff --git a/tests/integration/test_tls.py b/tests/integration/test_tls.py 
index e46825227..19ed087b7 100644 --- a/tests/integration/test_tls.py +++ b/tests/integration/test_tls.py @@ -1,20 +1,51 @@ +import logging +import os + import pytest from pymysql.err import OperationalError import datajoint as dj +# SSL tests require docker-compose with datajoint/mysql image (has SSL configured) +# Testcontainers with official mysql image doesn't have SSL certificates +requires_ssl = pytest.mark.skipif( + os.environ.get("DJ_USE_EXTERNAL_CONTAINERS", "").lower() not in ("1", "true", "yes"), + reason="SSL tests require external containers (docker-compose) with SSL configured", +) + + +@requires_ssl +def test_explicit_ssl_connection(db_creds_test, connection_test): + """When use_tls=True is specified, SSL must be active.""" + result = dj.conn(use_tls=True, reset=True, **db_creds_test).query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] + assert len(result) > 0, "SSL should be active when use_tls=True" + + +@requires_ssl +def test_ssl_auto_detect(db_creds_test, connection_test, caplog): + """When use_tls is not specified, SSL is preferred but fallback is allowed with warning.""" + with caplog.at_level(logging.WARNING): + conn = dj.conn(reset=True, **db_creds_test) + result = conn.query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] -def test_secure_connection(db_creds_test, connection_test): - result = dj.conn(reset=True, **db_creds_test).query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] - assert len(result) > 0 + if len(result) > 0: + # SSL connected successfully + assert "SSL connection failed" not in caplog.text + else: + # SSL failed and fell back - warning should be logged + assert "SSL connection failed" in caplog.text + assert "Falling back to non-SSL" in caplog.text def test_insecure_connection(db_creds_test, connection_test): + """When use_tls=False, SSL should not be used.""" result = dj.conn(use_tls=False, reset=True, **db_creds_test).query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] assert result == "" +@requires_ssl def test_reject_insecure(db_creds_test, connection_test): + """Users with REQUIRE SSL cannot connect without SSL.""" with pytest.raises(OperationalError): dj.conn( db_creds_test["host"], diff --git a/tests/unit/test_adapters.py b/tests/unit/test_adapters.py new file mode 100644 index 000000000..edbff9d52 --- /dev/null +++ b/tests/unit/test_adapters.py @@ -0,0 +1,544 @@ +""" +Unit tests for database adapters. + +Tests adapter functionality without requiring actual database connections. 
+""" + +import pytest + +from datajoint.adapters import DatabaseAdapter, MySQLAdapter, PostgreSQLAdapter, get_adapter + + +class TestAdapterRegistry: + """Test adapter registry and factory function.""" + + def test_get_adapter_mysql(self): + """Test getting MySQL adapter.""" + adapter = get_adapter("mysql") + assert isinstance(adapter, MySQLAdapter) + assert isinstance(adapter, DatabaseAdapter) + + def test_get_adapter_postgresql(self): + """Test getting PostgreSQL adapter.""" + pytest.importorskip("psycopg2") + adapter = get_adapter("postgresql") + assert isinstance(adapter, PostgreSQLAdapter) + assert isinstance(adapter, DatabaseAdapter) + + def test_get_adapter_postgres_alias(self): + """Test 'postgres' alias for PostgreSQL.""" + pytest.importorskip("psycopg2") + adapter = get_adapter("postgres") + assert isinstance(adapter, PostgreSQLAdapter) + + def test_get_adapter_case_insensitive(self): + """Test case-insensitive backend names.""" + assert isinstance(get_adapter("MySQL"), MySQLAdapter) + # Only test PostgreSQL if psycopg2 is available + try: + pytest.importorskip("psycopg2") + assert isinstance(get_adapter("POSTGRESQL"), PostgreSQLAdapter) + assert isinstance(get_adapter("PoStGrEs"), PostgreSQLAdapter) + except pytest.skip.Exception: + pass # Skip PostgreSQL tests if psycopg2 not available + + def test_get_adapter_invalid(self): + """Test error on invalid backend name.""" + with pytest.raises(ValueError, match="Unknown database backend"): + get_adapter("sqlite") + + +class TestMySQLAdapter: + """Test MySQL adapter implementation.""" + + @pytest.fixture + def adapter(self): + """MySQL adapter instance.""" + return MySQLAdapter() + + def test_default_port(self, adapter): + """Test MySQL default port is 3306.""" + assert adapter.default_port == 3306 + + def test_parameter_placeholder(self, adapter): + """Test MySQL parameter placeholder is %s.""" + assert adapter.parameter_placeholder == "%s" + + def test_quote_identifier(self, adapter): + """Test identifier quoting with backticks.""" + assert adapter.quote_identifier("table_name") == "`table_name`" + assert adapter.quote_identifier("my_column") == "`my_column`" + + def test_quote_string(self, adapter): + """Test string literal quoting.""" + assert "test" in adapter.quote_string("test") + # Should handle escaping + result = adapter.quote_string("It's a test") + assert "It" in result + + def test_core_type_to_sql_simple(self, adapter): + """Test core type mapping for simple types.""" + assert adapter.core_type_to_sql("int64") == "bigint" + assert adapter.core_type_to_sql("int32") == "int" + assert adapter.core_type_to_sql("int16") == "smallint" + assert adapter.core_type_to_sql("int8") == "tinyint" + assert adapter.core_type_to_sql("float32") == "float" + assert adapter.core_type_to_sql("float64") == "double" + assert adapter.core_type_to_sql("bool") == "tinyint" + assert adapter.core_type_to_sql("uuid") == "binary(16)" + assert adapter.core_type_to_sql("bytes") == "longblob" + assert adapter.core_type_to_sql("json") == "json" + assert adapter.core_type_to_sql("date") == "date" + + def test_core_type_to_sql_parametrized(self, adapter): + """Test core type mapping for parametrized types.""" + assert adapter.core_type_to_sql("datetime") == "datetime" + assert adapter.core_type_to_sql("datetime(3)") == "datetime(3)" + assert adapter.core_type_to_sql("char(10)") == "char(10)" + assert adapter.core_type_to_sql("varchar(255)") == "varchar(255)" + assert adapter.core_type_to_sql("decimal(10,2)") == "decimal(10,2)" + assert 
adapter.core_type_to_sql("enum('a','b','c')") == "enum('a','b','c')" + + def test_core_type_to_sql_invalid(self, adapter): + """Test error on invalid core type.""" + with pytest.raises(ValueError, match="Unknown core type"): + adapter.core_type_to_sql("invalid_type") + + def test_sql_type_to_core(self, adapter): + """Test reverse type mapping.""" + assert adapter.sql_type_to_core("bigint") == "int64" + assert adapter.sql_type_to_core("int") == "int32" + assert adapter.sql_type_to_core("float") == "float32" + assert adapter.sql_type_to_core("double") == "float64" + assert adapter.sql_type_to_core("longblob") == "bytes" + assert adapter.sql_type_to_core("datetime(3)") == "datetime(3)" + # Unmappable types return None + assert adapter.sql_type_to_core("mediumint") is None + + def test_create_schema_sql(self, adapter): + """Test CREATE DATABASE statement.""" + sql = adapter.create_schema_sql("test_db") + assert sql == "CREATE DATABASE `test_db`" + + def test_drop_schema_sql(self, adapter): + """Test DROP DATABASE statement.""" + sql = adapter.drop_schema_sql("test_db") + assert "DROP DATABASE" in sql + assert "IF EXISTS" in sql + assert "`test_db`" in sql + + def test_insert_sql_basic(self, adapter): + """Test basic INSERT statement.""" + sql = adapter.insert_sql("users", ["id", "name"]) + assert sql == "INSERT INTO users (`id`, `name`) VALUES (%s, %s)" + + def test_insert_sql_ignore(self, adapter): + """Test INSERT IGNORE statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="ignore") + assert "INSERT IGNORE" in sql + + def test_insert_sql_replace(self, adapter): + """Test REPLACE INTO statement.""" + sql = adapter.insert_sql("users", ["id"], on_duplicate="replace") + assert "REPLACE INTO" in sql + + def test_insert_sql_update(self, adapter): + """Test INSERT ... 
ON DUPLICATE KEY UPDATE statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="update") + assert "INSERT INTO" in sql + assert "ON DUPLICATE KEY UPDATE" in sql + + def test_update_sql(self, adapter): + """Test UPDATE statement.""" + sql = adapter.update_sql("users", ["name"], ["id"]) + assert "UPDATE users SET" in sql + assert "`name` = %s" in sql + assert "WHERE" in sql + assert "`id` = %s" in sql + + def test_delete_sql(self, adapter): + """Test DELETE statement.""" + sql = adapter.delete_sql("users") + assert sql == "DELETE FROM users" + + def test_current_timestamp_expr(self, adapter): + """Test CURRENT_TIMESTAMP expression.""" + assert adapter.current_timestamp_expr() == "CURRENT_TIMESTAMP" + assert adapter.current_timestamp_expr(3) == "CURRENT_TIMESTAMP(3)" + + def test_interval_expr(self, adapter): + """Test INTERVAL expression.""" + assert adapter.interval_expr(5, "second") == "INTERVAL 5 SECOND" + assert adapter.interval_expr(10, "minute") == "INTERVAL 10 MINUTE" + + def test_json_path_expr(self, adapter): + """Test JSON path extraction.""" + assert adapter.json_path_expr("data", "field") == "json_value(`data`, _utf8mb4'$.field')" + assert adapter.json_path_expr("record", "nested") == "json_value(`record`, _utf8mb4'$.nested')" + + def test_json_path_expr_with_return_type(self, adapter): + """Test JSON path extraction with return type.""" + result = adapter.json_path_expr("data", "value", "decimal(10,2)") + assert result == "json_value(`data`, _utf8mb4'$.value' returning decimal(10,2))" + + def test_transaction_sql(self, adapter): + """Test transaction statements.""" + assert "START TRANSACTION" in adapter.start_transaction_sql() + assert adapter.commit_sql() == "COMMIT" + assert adapter.rollback_sql() == "ROLLBACK" + + def test_validate_native_type(self, adapter): + """Test native type validation.""" + assert adapter.validate_native_type("int") + assert adapter.validate_native_type("bigint") + assert adapter.validate_native_type("varchar(255)") + assert adapter.validate_native_type("text") + assert adapter.validate_native_type("json") + assert not adapter.validate_native_type("invalid_type") + + +class TestPostgreSQLAdapter: + """Test PostgreSQL adapter implementation.""" + + @pytest.fixture + def adapter(self): + """PostgreSQL adapter instance.""" + # Skip if psycopg2 not installed + pytest.importorskip("psycopg2") + return PostgreSQLAdapter() + + def test_default_port(self, adapter): + """Test PostgreSQL default port is 5432.""" + assert adapter.default_port == 5432 + + def test_parameter_placeholder(self, adapter): + """Test PostgreSQL parameter placeholder is %s.""" + assert adapter.parameter_placeholder == "%s" + + def test_quote_identifier(self, adapter): + """Test identifier quoting with double quotes.""" + assert adapter.quote_identifier("table_name") == '"table_name"' + assert adapter.quote_identifier("my_column") == '"my_column"' + + def test_quote_string(self, adapter): + """Test string literal quoting.""" + assert adapter.quote_string("test") == "'test'" + # PostgreSQL doubles single quotes for escaping + assert adapter.quote_string("It's a test") == "'It''s a test'" + + def test_core_type_to_sql_simple(self, adapter): + """Test core type mapping for simple types.""" + assert adapter.core_type_to_sql("int64") == "bigint" + assert adapter.core_type_to_sql("int32") == "integer" + assert adapter.core_type_to_sql("int16") == "smallint" + assert adapter.core_type_to_sql("int8") == "smallint" # No tinyint in PostgreSQL + assert 
adapter.core_type_to_sql("float32") == "real" + assert adapter.core_type_to_sql("float64") == "double precision" + assert adapter.core_type_to_sql("bool") == "boolean" + assert adapter.core_type_to_sql("uuid") == "uuid" + assert adapter.core_type_to_sql("bytes") == "bytea" + assert adapter.core_type_to_sql("json") == "jsonb" + assert adapter.core_type_to_sql("date") == "date" + + def test_core_type_to_sql_parametrized(self, adapter): + """Test core type mapping for parametrized types.""" + assert adapter.core_type_to_sql("datetime") == "timestamp" + assert adapter.core_type_to_sql("datetime(3)") == "timestamp(3)" + assert adapter.core_type_to_sql("char(10)") == "char(10)" + assert adapter.core_type_to_sql("varchar(255)") == "varchar(255)" + assert adapter.core_type_to_sql("decimal(10,2)") == "numeric(10,2)" + + def test_sql_type_to_core(self, adapter): + """Test reverse type mapping.""" + assert adapter.sql_type_to_core("bigint") == "int64" + assert adapter.sql_type_to_core("integer") == "int32" + assert adapter.sql_type_to_core("real") == "float32" + assert adapter.sql_type_to_core("double precision") == "float64" + assert adapter.sql_type_to_core("boolean") == "bool" + assert adapter.sql_type_to_core("uuid") == "uuid" + assert adapter.sql_type_to_core("bytea") == "bytes" + assert adapter.sql_type_to_core("jsonb") == "json" + assert adapter.sql_type_to_core("timestamp") == "datetime" + assert adapter.sql_type_to_core("timestamp(3)") == "datetime(3)" + assert adapter.sql_type_to_core("numeric(10,2)") == "decimal(10,2)" + + def test_create_schema_sql(self, adapter): + """Test CREATE SCHEMA statement.""" + sql = adapter.create_schema_sql("test_schema") + assert sql == 'CREATE SCHEMA "test_schema"' + + def test_drop_schema_sql(self, adapter): + """Test DROP SCHEMA statement.""" + sql = adapter.drop_schema_sql("test_schema") + assert "DROP SCHEMA" in sql + assert "IF EXISTS" in sql + assert '"test_schema"' in sql + assert "CASCADE" in sql + + def test_insert_sql_basic(self, adapter): + """Test basic INSERT statement.""" + sql = adapter.insert_sql("users", ["id", "name"]) + assert sql == 'INSERT INTO users ("id", "name") VALUES (%s, %s)' + + def test_insert_sql_ignore(self, adapter): + """Test INSERT ... ON CONFLICT DO NOTHING statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="ignore") + assert "INSERT INTO" in sql + assert "ON CONFLICT DO NOTHING" in sql + + def test_insert_sql_update(self, adapter): + """Test INSERT ... 
ON CONFLICT DO UPDATE statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="update") + assert "INSERT INTO" in sql + assert "ON CONFLICT DO UPDATE" in sql + assert "EXCLUDED" in sql + + def test_update_sql(self, adapter): + """Test UPDATE statement.""" + sql = adapter.update_sql("users", ["name"], ["id"]) + assert "UPDATE users SET" in sql + assert '"name" = %s' in sql + assert "WHERE" in sql + assert '"id" = %s' in sql + + def test_delete_sql(self, adapter): + """Test DELETE statement.""" + sql = adapter.delete_sql("users") + assert sql == "DELETE FROM users" + + def test_current_timestamp_expr(self, adapter): + """Test CURRENT_TIMESTAMP expression.""" + assert adapter.current_timestamp_expr() == "CURRENT_TIMESTAMP" + assert adapter.current_timestamp_expr(3) == "CURRENT_TIMESTAMP(3)" + + def test_interval_expr(self, adapter): + """Test INTERVAL expression with PostgreSQL syntax.""" + assert adapter.interval_expr(5, "second") == "INTERVAL '5 seconds'" + assert adapter.interval_expr(10, "minute") == "INTERVAL '10 minutes'" + + def test_json_path_expr(self, adapter): + """Test JSON path extraction for PostgreSQL.""" + assert adapter.json_path_expr("data", "field") == "jsonb_extract_path_text(\"data\", 'field')" + assert adapter.json_path_expr("record", "name") == "jsonb_extract_path_text(\"record\", 'name')" + + def test_json_path_expr_nested(self, adapter): + """Test JSON path extraction with nested paths.""" + result = adapter.json_path_expr("data", "nested.field") + assert result == "jsonb_extract_path_text(\"data\", 'nested', 'field')" + + def test_transaction_sql(self, adapter): + """Test transaction statements.""" + assert adapter.start_transaction_sql() == "BEGIN" + assert adapter.commit_sql() == "COMMIT" + assert adapter.rollback_sql() == "ROLLBACK" + + def test_validate_native_type(self, adapter): + """Test native type validation.""" + assert adapter.validate_native_type("integer") + assert adapter.validate_native_type("bigint") + assert adapter.validate_native_type("varchar") + assert adapter.validate_native_type("text") + assert adapter.validate_native_type("jsonb") + assert adapter.validate_native_type("uuid") + assert adapter.validate_native_type("boolean") + assert not adapter.validate_native_type("invalid_type") + + def test_enum_type_sql(self, adapter): + """Test PostgreSQL enum type creation.""" + sql = adapter.create_enum_type_sql("myschema", "mytable", "status", ["pending", "complete"]) + assert "CREATE TYPE" in sql + assert "myschema_mytable_status_enum" in sql + assert "AS ENUM" in sql + assert "'pending'" in sql + assert "'complete'" in sql + + def test_drop_enum_type_sql(self, adapter): + """Test PostgreSQL enum type dropping.""" + sql = adapter.drop_enum_type_sql("myschema", "mytable", "status") + assert "DROP TYPE" in sql + assert "IF EXISTS" in sql + assert "myschema_mytable_status_enum" in sql + assert "CASCADE" in sql + + +class TestAdapterInterface: + """Test that adapters implement the full interface.""" + + @pytest.mark.parametrize("backend", ["mysql", "postgresql"]) + def test_adapter_implements_interface(self, backend): + """Test that adapter implements all abstract methods.""" + if backend == "postgresql": + pytest.importorskip("psycopg2") + + adapter = get_adapter(backend) + + # Check that all abstract methods are implemented (not abstract) + abstract_methods = [ + "connect", + "close", + "ping", + "get_connection_id", + "quote_identifier", + "quote_string", + "core_type_to_sql", + "sql_type_to_core", + "create_schema_sql", + 
"drop_schema_sql", + "create_table_sql", + "drop_table_sql", + "alter_table_sql", + "add_comment_sql", + "insert_sql", + "update_sql", + "delete_sql", + "list_schemas_sql", + "list_tables_sql", + "get_table_info_sql", + "get_columns_sql", + "get_primary_key_sql", + "get_foreign_keys_sql", + "get_indexes_sql", + "parse_column_info", + "start_transaction_sql", + "commit_sql", + "rollback_sql", + "current_timestamp_expr", + "interval_expr", + "json_path_expr", + "format_column_definition", + "table_options_clause", + "table_comment_ddl", + "column_comment_ddl", + "enum_type_ddl", + "job_metadata_columns", + "translate_error", + "validate_native_type", + ] + + for method_name in abstract_methods: + assert hasattr(adapter, method_name), f"Adapter missing method: {method_name}" + method = getattr(adapter, method_name) + assert callable(method), f"Adapter.{method_name} is not callable" + + # Check properties + assert hasattr(adapter, "default_port") + assert isinstance(adapter.default_port, int) + assert hasattr(adapter, "parameter_placeholder") + assert isinstance(adapter.parameter_placeholder, str) + + +class TestDDLMethods: + """Test DDL generation adapter methods.""" + + @pytest.fixture + def adapter(self): + """MySQL adapter instance.""" + return MySQLAdapter() + + def test_format_column_definition_mysql(self, adapter): + """Test MySQL column definition formatting.""" + result = adapter.format_column_definition("user_id", "bigint", nullable=False, comment="user ID") + assert result == '`user_id` bigint NOT NULL COMMENT "user ID"' + + # Test without comment + result = adapter.format_column_definition("name", "varchar(255)", nullable=False) + assert result == "`name` varchar(255) NOT NULL" + + # Test nullable + result = adapter.format_column_definition("description", "text", nullable=True) + assert result == "`description` text" + + # Test with default + result = adapter.format_column_definition("status", "int", default="DEFAULT 1") + assert result == "`status` int DEFAULT 1" + + def test_table_options_clause_mysql(self, adapter): + """Test MySQL table options clause.""" + result = adapter.table_options_clause("test table") + assert result == 'ENGINE=InnoDB, COMMENT "test table"' + + result = adapter.table_options_clause() + assert result == "ENGINE=InnoDB" + + def test_table_comment_ddl_mysql(self, adapter): + """Test MySQL table comment DDL (should be None).""" + result = adapter.table_comment_ddl("`schema`.`table`", "test comment") + assert result is None + + def test_column_comment_ddl_mysql(self, adapter): + """Test MySQL column comment DDL (should be None).""" + result = adapter.column_comment_ddl("`schema`.`table`", "column", "test comment") + assert result is None + + def test_enum_type_ddl_mysql(self, adapter): + """Test MySQL enum type DDL (should be None).""" + result = adapter.enum_type_ddl("status_type", ["active", "inactive"]) + assert result is None + + def test_job_metadata_columns_mysql(self, adapter): + """Test MySQL job metadata columns.""" + result = adapter.job_metadata_columns() + assert len(result) == 3 + assert "_job_start_time" in result[0] + assert "datetime(3)" in result[0] + assert "_job_duration" in result[1] + assert "float" in result[1] + assert "_job_version" in result[2] + assert "varchar(64)" in result[2] + + +class TestPostgreSQLDDLMethods: + """Test PostgreSQL-specific DDL generation methods.""" + + @pytest.fixture + def postgres_adapter(self): + """Get PostgreSQL adapter for testing.""" + pytest.importorskip("psycopg2") + return get_adapter("postgresql") + 
+ def test_format_column_definition_postgres(self, postgres_adapter): + """Test PostgreSQL column definition formatting.""" + result = postgres_adapter.format_column_definition("user_id", "bigint", nullable=False, comment="user ID") + assert result == '"user_id" bigint NOT NULL' + + # Test without comment (comment handled separately in PostgreSQL) + result = postgres_adapter.format_column_definition("name", "varchar(255)", nullable=False) + assert result == '"name" varchar(255) NOT NULL' + + # Test nullable + result = postgres_adapter.format_column_definition("description", "text", nullable=True) + assert result == '"description" text' + + def test_table_options_clause_postgres(self, postgres_adapter): + """Test PostgreSQL table options clause (should be empty).""" + result = postgres_adapter.table_options_clause("test table") + assert result == "" + + result = postgres_adapter.table_options_clause() + assert result == "" + + def test_table_comment_ddl_postgres(self, postgres_adapter): + """Test PostgreSQL table comment DDL.""" + result = postgres_adapter.table_comment_ddl('"schema"."table"', "test comment") + assert result == 'COMMENT ON TABLE "schema"."table" IS \'test comment\'' + + def test_column_comment_ddl_postgres(self, postgres_adapter): + """Test PostgreSQL column comment DDL.""" + result = postgres_adapter.column_comment_ddl('"schema"."table"', "column", "test comment") + assert result == 'COMMENT ON COLUMN "schema"."table"."column" IS \'test comment\'' + + def test_enum_type_ddl_postgres(self, postgres_adapter): + """Test PostgreSQL enum type DDL.""" + result = postgres_adapter.enum_type_ddl("status_type", ["active", "inactive"]) + assert result == "CREATE TYPE \"status_type\" AS ENUM ('active', 'inactive')" + + def test_job_metadata_columns_postgres(self, postgres_adapter): + """Test PostgreSQL job metadata columns.""" + result = postgres_adapter.job_metadata_columns() + assert len(result) == 3 + assert "_job_start_time" in result[0] + assert "timestamp" in result[0] + assert "_job_duration" in result[1] + assert "real" in result[1] + assert "_job_version" in result[2] + assert "varchar(64)" in result[2] diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 61f4439e0..af5718503 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -748,3 +748,123 @@ def test_similar_prefix_names_allowed(self): finally: dj.config.stores.clear() dj.config.stores.update(original_stores) + + +class TestBackendConfiguration: + """Test database backend configuration and port auto-detection.""" + + def test_backend_default(self): + """Test default backend is mysql.""" + from datajoint.settings import DatabaseSettings + + settings = DatabaseSettings() + assert settings.backend == "mysql" + assert settings.port == 3306 + + def test_backend_postgresql(self, monkeypatch): + """Test PostgreSQL backend with auto port.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 5432 + + def test_backend_explicit_port_overrides(self, monkeypatch): + """Test explicit port overrides auto-detection.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + monkeypatch.setenv("DJ_PORT", "9999") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 9999 + + def test_backend_env_var(self, monkeypatch): + """Test DJ_BACKEND environment 
variable.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 5432 + + def test_port_env_var_overrides_backend_default(self, monkeypatch): + """Test DJ_PORT overrides backend auto-detection.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + monkeypatch.setenv("DJ_PORT", "8888") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 8888 + + def test_invalid_backend(self, monkeypatch): + """Test invalid backend raises validation error.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "sqlite") + with pytest.raises(ValidationError, match="Input should be 'mysql' or 'postgresql'"): + DatabaseSettings() + + def test_config_file_backend(self, tmp_path, monkeypatch): + """Test loading backend from config file.""" + import json + + from datajoint.settings import Config + + # Include port in config since auto-detection only happens during initialization + config_file = tmp_path / "test_config.json" + config_file.write_text(json.dumps({"database": {"backend": "postgresql", "host": "db.example.com", "port": 5432}})) + + # Clear env vars so file values take effect + monkeypatch.delenv("DJ_BACKEND", raising=False) + monkeypatch.delenv("DJ_HOST", raising=False) + monkeypatch.delenv("DJ_PORT", raising=False) + + cfg = Config() + cfg.load(config_file) + assert cfg.database.backend == "postgresql" + assert cfg.database.port == 5432 + assert cfg.database.host == "db.example.com" + + def test_global_config_backend(self): + """Test global config has backend configuration.""" + # Global config should have backend field with default mysql + assert hasattr(dj.config.database, "backend") + # Backend should be one of the valid values + assert dj.config.database.backend in ["mysql", "postgresql"] + # Port should be set (either 3306 or 5432 or custom) + assert isinstance(dj.config.database.port, int) + assert 1 <= dj.config.database.port <= 65535 + + def test_port_auto_detection_on_initialization(self): + """Test port auto-detects only during initialization, not on live updates.""" + from datajoint.settings import DatabaseSettings + + # Start with MySQL (default) + settings = DatabaseSettings() + assert settings.port == 3306 + + # Change backend on live config - port won't auto-update + settings.backend = "postgresql" + # Port remains at previous value (this is expected behavior) + # Users should set port explicitly when changing backend on live config + assert settings.port == 3306 # Didn't auto-update + + def test_mysql_backend_with_explicit_port(self, monkeypatch): + """Test MySQL backend with explicit non-default port.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "mysql") + monkeypatch.setenv("DJ_PORT", "3307") + settings = DatabaseSettings() + assert settings.backend == "mysql" + assert settings.port == 3307 + + def test_backend_field_in_env_var_mapping(self): + """Test that backend is mapped to DJ_BACKEND in ENV_VAR_MAPPING.""" + from datajoint.settings import ENV_VAR_MAPPING + + assert "database.backend" in ENV_VAR_MAPPING + assert ENV_VAR_MAPPING["database.backend"] == "DJ_BACKEND"