diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4746b0f4..be233755 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,8 @@ jobs: run: uv tool run ty check --python-version 3.10 - name: Build the stack - run: docker compose up -d --wait mysql postgres presto trino clickhouse + # --profile full unlocks profile-gated clickhouse; only named services start + run: docker compose --profile full up -d --wait mysql postgres clickhouse - name: Run tests env: diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..b7e636f8 --- /dev/null +++ b/Makefile @@ -0,0 +1,45 @@ +.PHONY: up up-full down test-unit test demo + +## Start PostgreSQL + MySQL (lightweight, fast startup) +up: + docker compose up -d --wait postgres mysql + +## Start all services including ClickHouse, Presto, Trino, Vertica +up-full: + docker compose --profile full up -d --wait + +## Stop all services and remove volumes +down: + docker compose --profile full down -v + +## Run unit tests (no database required) +test-unit: + uv run pytest tests/test_query.py tests/test_utils.py -x + +## Run full test suite against PG + MySQL (starts containers if needed) +## To also test Presto/Trino/Vertica, run `make up-full` first and set: +## export DATADIFF_PRESTO_URI="presto://test@localhost:8080/memory/default" +## export DATADIFF_TRINO_URI="trino://test@localhost:8081/memory/default" +## export DATADIFF_VERTICA_URI="vertica://vertica:Password1@localhost:5433/vertica" +test: up + uv run pytest tests/ \ + -o addopts="--timeout=300 --tb=short" \ + --ignore=tests/test_database_types.py \ + --ignore=tests/test_dbt_config_validators.py \ + --ignore=tests/test_main.py + +## Run data-diff against seed data to showcase diffing +demo: up + @echo "=== PostgreSQL: ratings_source vs ratings_target ===" + uv run python -m data_diff \ + postgresql://postgres:Password1@localhost/postgres \ + ratings_source ratings_target \ + --key-columns id \ + --columns rating + @echo "" + @echo "=== MySQL: ratings_source vs ratings_target ===" + uv run python -m data_diff \ + mysql://mysql:Password1@localhost/mysql \ + ratings_source ratings_target \ + --key-columns id \ + --columns rating diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 5211be6d..46eeedd1 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -25,7 +25,6 @@ CHECKSUM_HEXDIGITS, CHECKSUM_OFFSET, MD5_HEXDIGITS, - TIMESTAMP_PRECISION_POS, BaseDialect, ConnectError, ThreadedDatabase, @@ -115,8 +114,14 @@ def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - def _add_padding(coltype: TemporalType, timestamp6: str): - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS + coltype.precision}), {TIMESTAMP_PRECISION_POS + 6}, '0')" + def _truncate_and_pad(coltype: TemporalType, timestamp6: str): + """Truncate a 6-digit-precision timestamp to target precision, then zero-pad back to 6 digits.""" + truncated = f"LEFT({timestamp6}, length({timestamp6}) - (6 - {coltype.precision}))" + return f"RPAD({truncated}, length({timestamp6}), '0')" + + def _zero_pad(coltype: TemporalType, already_truncated: str): + """Zero-pad an already-truncated timestamp back to 6 fractional digits.""" + return f"RPAD({already_truncated}, length({already_truncated}) + (6 - {coltype.precision}), '0')" try: is_date = coltype.is_date @@ -141,30 +146,28 @@ def _add_padding(coltype: TemporalType, timestamp6: str): null_case_end = "END" # 294277 or 4714 BC would be out of range, make sure we can't round to that - # TODO test timezones for overflow? max_timestamp = "294276-12-31 23:59:59.0000" min_timestamp = "4713-01-01 00:00:00.00 BC" - timestamp = f"least('{max_timestamp}'::timestamp(6), {value}::timestamp(6))" - timestamp = f"greatest('{min_timestamp}'::timestamp(6), {timestamp})" + ts_type = "timestamptz(6)" if isinstance(coltype, TimestampTZ) else "timestamp(6)" + timestamp = f"least('{max_timestamp}'::{ts_type}, {value}::{ts_type})" + timestamp = f"greatest('{min_timestamp}'::{ts_type}, {timestamp})" interval = format((0.5 * (10 ** (-coltype.precision))), f".{coltype.precision + 1}f") rounded_timestamp = ( - f"left(to_char(least('{max_timestamp}'::timestamp, {timestamp})" + f"left(to_char(least('{max_timestamp}'::{ts_type}, {timestamp})" f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US')," - f"length(to_char(least('{max_timestamp}'::timestamp, {timestamp})" + f"length(to_char(least('{max_timestamp}'::{ts_type}, {timestamp})" f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US')) - (6-{coltype.precision}))" ) - padded = _add_padding(coltype, rounded_timestamp) + padded = _zero_pad(coltype, rounded_timestamp) return f"{null_case_begin} {padded} {null_case_end}" - # TODO years with > 4 digits not padded correctly - # current w/ precision 6: 294276-12-31 23:59:59.0000 - # should be 294276-12-31 23:59:59.000000 else: - rounded_timestamp = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - padded = _add_padding(coltype, rounded_timestamp) + ts_type = "timestamptz(6)" if isinstance(coltype, TimestampTZ) else "timestamp(6)" + rounded_timestamp = f"to_char({value}::{ts_type}, 'YYYY-mm-dd HH24:MI:SS.US')" + padded = _truncate_and_pad(coltype, rounded_timestamp) return padded def normalize_number(self, value: str, coltype: FractionalType) -> str: diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index c5fb8179..a1d81992 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -483,7 +483,8 @@ class Join(ExprNode, ITable, Root): def schema(self) -> Schema: if not self.columns: raise ValueError("Join must specify columns explicitly (SELECT * not yet implemented).") - s = self.source_tables[0].schema # TODO validate types match between both tables + # No cross-table type validation needed: join combines columns from both tables rather than unioning rows + s = self.source_tables[0].schema return type(s)({c.name: c.type for c in self.columns}) def on(self, *exprs) -> Self: @@ -553,15 +554,28 @@ class TableOp(ExprNode, ITable, Root): @property def type(self): - # TODO ensure types of both tables are compatible - return self.table1.type + t1 = self.table1.type + t2 = self.table2.type + if t1 is None or t2 is None: + return None + if type(t1) is not type(t2): + raise QueryBuilderError(f"Type mismatch in {self.op}: got {type(t1).__name__} and {type(t2).__name__}") + return t1 @property def schema(self) -> Schema: s1 = self.table1.schema s2 = self.table2.schema + if s1 is None or s2 is None: + raise QueryBuilderError(f"Cannot validate {self.op}: one or both tables have no schema defined") if len(s1) != len(s2): - raise ValueError(f"TableOp requires tables with matching schema lengths, got {len(s1)} and {len(s2)}.") + raise QueryBuilderError(f"Schema length mismatch in {self.op}: got {len(s1)} and {len(s2)} columns") + for (name1, type1), (name2, type2) in zip(s1.items(), s2.items()): + if type(type1) is not type(type2): + raise QueryBuilderError( + f"Type mismatch in {self.op}: column {name1!r} is {type(type1).__name__} " + f"but column {name2!r} is {type(type2).__name__}" + ) return s1 diff --git a/dev/seed/mysql/01_seed.sql b/dev/seed/mysql/01_seed.sql new file mode 100644 index 00000000..c30e7045 --- /dev/null +++ b/dev/seed/mysql/01_seed.sql @@ -0,0 +1,59 @@ +-- Seed data for demonstrating data-diff capabilities. +-- Auto-executed by MySQL on first container startup. + +CREATE TABLE ratings_source ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + movie_id INT NOT NULL, + rating DECIMAL(2,1) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE ratings_target ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + movie_id INT NOT NULL, + rating DECIMAL(2,1) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Populate source with 1000 rows via stored procedure (MySQL lacks generate_series) +DELIMITER // +CREATE PROCEDURE seed_ratings() +BEGIN + DECLARE i INT DEFAULT 1; + WHILE i <= 1000 DO + INSERT INTO ratings_source (id, user_id, movie_id, rating, created_at) + VALUES ( + i, + 1 + (i % 200), + 1 + (i % 50), + 1 + (i % 5), + DATE_ADD('2025-01-01', INTERVAL i MINUTE) + ); + SET i = i + 1; + END WHILE; +END // +DELIMITER ; + +CALL seed_ratings(); +DROP PROCEDURE seed_ratings; + +-- Copy all rows into target +INSERT INTO ratings_target SELECT * FROM ratings_source; + +-- Introduce diffs: +-- 5 deleted rows (IDs 10-14 missing from target) +DELETE FROM ratings_target WHERE id BETWEEN 10 AND 14; + +-- 5 extra rows in target only (IDs 1001-1005) +INSERT INTO ratings_target (id, user_id, movie_id, rating, created_at) VALUES + (1001, 201, 51, 4.0, '2025-06-01 00:00:00'), + (1002, 202, 52, 3.0, '2025-06-02 00:00:00'), + (1003, 203, 53, 5.0, '2025-06-03 00:00:00'), + (1004, 204, 54, 2.0, '2025-06-04 00:00:00'), + (1005, 205, 55, 1.0, '2025-06-05 00:00:00'); + +-- 10 updated ratings (IDs 100-109 have different ratings in target) +UPDATE ratings_target SET rating = rating + 0.5 WHERE id BETWEEN 100 AND 109 AND rating < 5.0; +UPDATE ratings_target SET rating = 1.0 WHERE id BETWEEN 100 AND 109 AND rating >= 5.0; diff --git a/dev/seed/postgres/01_seed.sql b/dev/seed/postgres/01_seed.sql new file mode 100644 index 00000000..bb886eeb --- /dev/null +++ b/dev/seed/postgres/01_seed.sql @@ -0,0 +1,47 @@ +-- Seed data for demonstrating data-diff capabilities. +-- Auto-executed by PostgreSQL on first container startup. + +CREATE TABLE ratings_source ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + movie_id INTEGER NOT NULL, + rating NUMERIC(2,1) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT now() +); + +CREATE TABLE ratings_target ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + movie_id INTEGER NOT NULL, + rating NUMERIC(2,1) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT now() +); + +-- Populate source with 1000 rows +INSERT INTO ratings_source (id, user_id, movie_id, rating, created_at) +SELECT + g AS id, + 1 + (g % 200) AS user_id, + 1 + (g % 50) AS movie_id, + (1 + (g % 5))::NUMERIC(2,1) AS rating, + '2025-01-01'::TIMESTAMP + (g || ' minutes')::INTERVAL AS created_at +FROM generate_series(1, 1000) AS g; + +-- Copy all rows into target +INSERT INTO ratings_target SELECT * FROM ratings_source; + +-- Introduce diffs: +-- 5 deleted rows (IDs 10-14 missing from target) +DELETE FROM ratings_target WHERE id BETWEEN 10 AND 14; + +-- 5 extra rows in target only (IDs 1001-1005) +INSERT INTO ratings_target (id, user_id, movie_id, rating, created_at) VALUES + (1001, 201, 51, 4.0, '2025-06-01 00:00:00'), + (1002, 202, 52, 3.0, '2025-06-02 00:00:00'), + (1003, 203, 53, 5.0, '2025-06-03 00:00:00'), + (1004, 204, 54, 2.0, '2025-06-04 00:00:00'), + (1005, 205, 55, 1.0, '2025-06-05 00:00:00'); + +-- 10 updated ratings (IDs 100-109 have different ratings in target) +UPDATE ratings_target SET rating = rating + 0.5 WHERE id BETWEEN 100 AND 109 AND rating < 5.0; +UPDATE ratings_target SET rating = 1.0 WHERE id BETWEEN 100 AND 109 AND rating >= 5.0; diff --git a/docker-compose.yml b/docker-compose.yml index 26acc497..c69bd658 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,6 +12,7 @@ services: restart: always volumes: - postgresql-data:/var/lib/postgresql/data:delegated + - ./dev/seed/postgres:/docker-entrypoint-initdb.d:ro ports: - '5432:5432' expose: @@ -42,6 +43,7 @@ services: restart: always volumes: - mysql-data:/var/lib/mysql:delegated + - ./dev/seed/mysql:/docker-entrypoint-initdb.d:ro user: mysql ports: - '3306:3306' @@ -61,6 +63,7 @@ services: clickhouse: container_name: dd-clickhouse image: clickhouse/clickhouse-server:24.3 + profiles: [full] restart: always volumes: - clickhouse-data:/var/lib/clickhouse:delegated @@ -88,6 +91,7 @@ services: # prestodb.dbapi.connect(host="127.0.0.1", user="presto").cursor().execute('SELECT * FROM system.runtime.nodes') presto: + profiles: [full] container_name: dd-presto build: context: ./dev @@ -101,6 +105,7 @@ services: - local trino: + profiles: [full] container_name: dd-trino image: 'trinodb/trino:439' hostname: trino @@ -118,6 +123,7 @@ services: vertica: container_name: dd-vertica + profiles: [full] image: vertica/vertica-ce:24.1.0-0 restart: always volumes: diff --git a/tests/common.py b/tests/common.py index db9a4ba0..23717d15 100644 --- a/tests/common.py +++ b/tests/common.py @@ -27,14 +27,12 @@ TEST_BIGQUERY_CONN_STRING: str = os.environ.get("DATADIFF_BIGQUERY_URI") or None TEST_REDSHIFT_CONN_STRING: str = os.environ.get("DATADIFF_REDSHIFT_URI") or None TEST_ORACLE_CONN_STRING: str = None -TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI") +TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI") or None TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or None -# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" -TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("DATADIFF_CLICKHOUSE_URI") -# vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica" -TEST_VERTICA_CONN_STRING: str = os.environ.get("DATADIFF_VERTICA_URI") +TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("DATADIFF_CLICKHOUSE_URI") or None +TEST_VERTICA_CONN_STRING: str = os.environ.get("DATADIFF_VERTICA_URI") or None TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" -TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") +TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") or None DEFAULT_N_SAMPLES = 50 diff --git a/tests/test_query.py b/tests/test_query.py index 722a809c..502cdd75 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -5,10 +5,11 @@ import attrs -from data_diff.abcs.database_types import FractionalType, TemporalType +from data_diff.abcs.database_types import FractionalType, Integer, TemporalType, Text, Timestamp, TimestampTZ from data_diff.databases.base import BaseDialect, CompileError, Compiler, Database +from data_diff.databases.postgresql import PostgresqlDialect from data_diff.queries.api import coalesce, code, cte, outerjoin, table, this, when -from data_diff.queries.ast_classes import QueryBuilderError, Random +from data_diff.queries.ast_classes import QueryBuilderError, Random, TableOp from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict @@ -443,3 +444,136 @@ def compile_cte(i): self.assertEqual(len(results), num_threads) with_results = [r for r in results if "WITH" in r] self.assertGreater(len(with_results), 0, "At least one result should have a WITH clause") + + +class TestTableOpTypeValidation(unittest.TestCase): + def test_union_matching_types_succeeds(self): + schema_a = CaseSensitiveDict({"x": Integer(), "y": Text()}) + schema_b = CaseSensitiveDict({"x": Integer(), "y": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + u = a.union(b) + # Should not raise + self.assertEqual(len(u.schema), 2) + + def test_union_mismatched_types_raises(self): + schema_a = CaseSensitiveDict({"x": Integer(), "y": Text()}) + schema_b = CaseSensitiveDict({"x": Text(), "y": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + u = a.union(b) + with self.assertRaises(QueryBuilderError): + _ = u.schema + + def test_intersect_mismatched_types_raises(self): + schema_a = CaseSensitiveDict({"x": Integer()}) + schema_b = CaseSensitiveDict({"x": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.intersect(b) + with self.assertRaises(QueryBuilderError): + _ = op.schema + + def test_minus_mismatched_types_raises(self): + schema_a = CaseSensitiveDict({"x": Integer()}) + schema_b = CaseSensitiveDict({"x": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.minus(b) + with self.assertRaises(QueryBuilderError): + _ = op.schema + + def test_type_property_with_none_types_passes(self): + schema_a = CaseSensitiveDict({"x": Integer()}) + schema_b = CaseSensitiveDict({"x": Integer()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + u = a.union(b) + # Select-wrapped tables return None for .type — should pass without error + self.assertIsNone(u.type) + + def test_schema_length_mismatch_raises_query_builder_error(self): + schema_a = CaseSensitiveDict({"x": Integer(), "y": Text()}) + schema_b = CaseSensitiveDict({"x": Integer()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.union(b) + with self.assertRaises(QueryBuilderError): + _ = op.schema + + def test_schema_mismatch_error_includes_details(self): + schema_a = CaseSensitiveDict({"col_a": Integer()}) + schema_b = CaseSensitiveDict({"col_b": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.union(b) + with self.assertRaises(QueryBuilderError) as ctx: + _ = op.schema + self.assertIn("col_a", str(ctx.exception)) + self.assertIn("UNION", str(ctx.exception)) + + def test_type_mismatch_raises_query_builder_error(self): + """TableOp.type raises when both sides have non-None but different types.""" + + class FakeTable: + schema = None + + def __init__(self, type_val): + self.type = type_val + + op = TableOp("UNION", FakeTable(Integer()), FakeTable(Text())) + with self.assertRaises(QueryBuilderError) as ctx: + _ = op.type + self.assertIn("UNION", str(ctx.exception)) + self.assertIn("Integer", str(ctx.exception)) + self.assertIn("Text", str(ctx.exception)) + + def test_type_matching_returns_type(self): + """TableOp.type returns the type when both sides match.""" + + class FakeTable: + schema = None + + def __init__(self, type_val): + self.type = type_val + + t = Integer() + op = TableOp("UNION", FakeTable(t), FakeTable(Integer())) + self.assertIsInstance(op.type, Integer) + + +class TestPostgresqlTimestampNormalization(unittest.TestCase): + def setUp(self): + self.dialect = PostgresqlDialect() + + def test_timestamp_uses_timestamp_cast(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=6, rounds=True)) + self.assertIn("::timestamp(6)", result) + self.assertNotIn("::timestamptz", result) + + def test_timestamptz_uses_timestamptz_cast(self): + result = self.dialect.normalize_timestamp("col", TimestampTZ(precision=6, rounds=True)) + self.assertIn("::timestamptz(6)", result) + self.assertNotIn("::timestamp(6)", result) + + def test_rounding_padding_uses_zero_pad(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=3, rounds=True)) + self.assertIn("length(", result) + # Rounding branch uses RPAD without LEFT (already truncated) + self.assertIn("RPAD(", result) + + def test_non_rounding_uses_truncate_and_pad(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=3, rounds=False)) + self.assertIn("length(", result) + self.assertIn("RPAD(LEFT(", result) + self.assertNotIn("CASE WHEN", result) + + def test_non_rounding_timestamptz_uses_timestamptz_cast(self): + result = self.dialect.normalize_timestamp("col", TimestampTZ(precision=6, rounds=False)) + self.assertIn("::timestamptz(6)", result) + self.assertNotIn("::timestamp(6)", result) + + def test_precision_zero_rounding(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=0, rounds=True)) + self.assertIn("RPAD(", result) + self.assertIn("(6 - 0)", result)