From 5c2a6bae5273e316ec7ae92604a34d8b61ce47b2 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 01:27:45 -0800 Subject: [PATCH 1/5] fix: validate column types in set operations and fix PostgreSQL timestamp edge cases Add type-class validation to TableOp.schema and TableOp.type so that unions, intersects, and minus operations reject mismatched column types early with a clear QueryBuilderError instead of silently producing incorrect results. (#5) Fix PostgreSQL timestamp normalization: use timestamptz(6) cast for TimestampTZ columns to preserve timezone info during bounds comparison, and replace hardcoded TIMESTAMP_PRECISION_POS with length()-based calculation to correctly pad years with >4 digits. (#12) Co-Authored-By: Claude Opus 4.6 --- data_diff/databases/postgresql.py | 18 ++++---- data_diff/queries/ast_classes.py | 15 +++++-- tests/test_query.py | 72 ++++++++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 14 deletions(-) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 5211be6d..25a624f6 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -25,7 +25,6 @@ CHECKSUM_HEXDIGITS, CHECKSUM_OFFSET, MD5_HEXDIGITS, - TIMESTAMP_PRECISION_POS, BaseDialect, ConnectError, ThreadedDatabase, @@ -116,7 +115,9 @@ def md5_as_hex(self, s: str) -> str: def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: def _add_padding(coltype: TemporalType, timestamp6: str): - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS + coltype.precision}), {TIMESTAMP_PRECISION_POS + 6}, '0')" + return ( + f"RPAD(LEFT({timestamp6}, length({timestamp6}) - (6 - {coltype.precision})), length({timestamp6}), '0')" + ) try: is_date = coltype.is_date @@ -141,27 +142,24 @@ def _add_padding(coltype: TemporalType, timestamp6: str): null_case_end = "END" # 294277 or 4714 BC would be out of range, make sure we can't round to that - # TODO test timezones for overflow? max_timestamp = "294276-12-31 23:59:59.0000" min_timestamp = "4713-01-01 00:00:00.00 BC" - timestamp = f"least('{max_timestamp}'::timestamp(6), {value}::timestamp(6))" - timestamp = f"greatest('{min_timestamp}'::timestamp(6), {timestamp})" + ts_type = "timestamptz(6)" if isinstance(coltype, TimestampTZ) else "timestamp(6)" + timestamp = f"least('{max_timestamp}'::{ts_type}, {value}::{ts_type})" + timestamp = f"greatest('{min_timestamp}'::{ts_type}, {timestamp})" interval = format((0.5 * (10 ** (-coltype.precision))), f".{coltype.precision + 1}f") rounded_timestamp = ( - f"left(to_char(least('{max_timestamp}'::timestamp, {timestamp})" + f"left(to_char(least('{max_timestamp}'::{ts_type}, {timestamp})" f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US')," - f"length(to_char(least('{max_timestamp}'::timestamp, {timestamp})" + f"length(to_char(least('{max_timestamp}'::{ts_type}, {timestamp})" f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US')) - (6-{coltype.precision}))" ) padded = _add_padding(coltype, rounded_timestamp) return f"{null_case_begin} {padded} {null_case_end}" - # TODO years with > 4 digits not padded correctly - # current w/ precision 6: 294276-12-31 23:59:59.0000 - # should be 294276-12-31 23:59:59.000000 else: rounded_timestamp = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" padded = _add_padding(coltype, rounded_timestamp) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index c5fb8179..03d53123 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -483,7 +483,7 @@ class Join(ExprNode, ITable, Root): def schema(self) -> Schema: if not self.columns: raise ValueError("Join must specify columns explicitly (SELECT * not yet implemented).") - s = self.source_tables[0].schema # TODO validate types match between both tables + s = self.source_tables[0].schema return type(s)({c.name: c.type for c in self.columns}) def on(self, *exprs) -> Self: @@ -553,8 +553,11 @@ class TableOp(ExprNode, ITable, Root): @property def type(self): - # TODO ensure types of both tables are compatible - return self.table1.type + t1 = self.table1.type + t2 = self.table2.type + if type(t1) is not type(t2): + raise QueryBuilderError(f"Type mismatch in {self.op}: got {type(t1).__name__} and {type(t2).__name__}") + return t1 @property def schema(self) -> Schema: @@ -562,6 +565,12 @@ def schema(self) -> Schema: s2 = self.table2.schema if len(s1) != len(s2): raise ValueError(f"TableOp requires tables with matching schema lengths, got {len(s1)} and {len(s2)}.") + for (name1, type1), (name2, type2) in zip(s1.items(), s2.items()): + if type(type1) is not type(type2): + raise QueryBuilderError( + f"Type mismatch in {self.op}: column {name1!r} is {type(type1).__name__} " + f"but column {name2!r} is {type(type2).__name__}" + ) return s1 diff --git a/tests/test_query.py b/tests/test_query.py index 722a809c..5d5f10b5 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -5,8 +5,9 @@ import attrs -from data_diff.abcs.database_types import FractionalType, TemporalType +from data_diff.abcs.database_types import FractionalType, Integer, TemporalType, Text, Timestamp, TimestampTZ from data_diff.databases.base import BaseDialect, CompileError, Compiler, Database +from data_diff.databases.postgresql import PostgresqlDialect from data_diff.queries.api import coalesce, code, cte, outerjoin, table, this, when from data_diff.queries.ast_classes import QueryBuilderError, Random from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict @@ -443,3 +444,72 @@ def compile_cte(i): self.assertEqual(len(results), num_threads) with_results = [r for r in results if "WITH" in r] self.assertGreater(len(with_results), 0, "At least one result should have a WITH clause") + + +class TestTableOpTypeValidation(unittest.TestCase): + def test_union_matching_types_succeeds(self): + schema_a = CaseSensitiveDict({"x": Integer(), "y": Text()}) + schema_b = CaseSensitiveDict({"x": Integer(), "y": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + u = a.union(b) + # Should not raise + self.assertEqual(len(u.schema), 2) + + def test_union_mismatched_types_raises(self): + schema_a = CaseSensitiveDict({"x": Integer(), "y": Text()}) + schema_b = CaseSensitiveDict({"x": Text(), "y": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + u = a.union(b) + with self.assertRaises(QueryBuilderError): + _ = u.schema + + def test_intersect_mismatched_types_raises(self): + schema_a = CaseSensitiveDict({"x": Integer()}) + schema_b = CaseSensitiveDict({"x": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.intersect(b) + with self.assertRaises(QueryBuilderError): + _ = op.schema + + def test_minus_mismatched_types_raises(self): + schema_a = CaseSensitiveDict({"x": Integer()}) + schema_b = CaseSensitiveDict({"x": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.minus(b) + with self.assertRaises(QueryBuilderError): + _ = op.schema + + def test_type_property_validates_when_types_available(self): + # TableOp.type delegates to table1/table2.type; Select returns None + # so type validation is a pass-through for Select-wrapped tables. + # The real enforcement is in .schema validation above. + schema_a = CaseSensitiveDict({"x": Integer()}) + schema_b = CaseSensitiveDict({"x": Integer()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + u = a.union(b) + # Both raw tables return None for .type, so no mismatch + self.assertIsNone(u.type) + + +class TestPostgresqlTimestampNormalization(unittest.TestCase): + def setUp(self): + self.dialect = PostgresqlDialect() + + def test_timestamp_uses_timestamp_cast(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=6, rounds=True)) + self.assertIn("::timestamp(6)", result) + self.assertNotIn("::timestamptz", result) + + def test_timestamptz_uses_timestamptz_cast(self): + result = self.dialect.normalize_timestamp("col", TimestampTZ(precision=6, rounds=True)) + self.assertIn("::timestamptz(6)", result) + + def test_padding_uses_length_based_calculation(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=3, rounds=True)) + self.assertIn("length(", result) + self.assertIn("RPAD(LEFT(", result) From d0021615d9476a726f43e6c0039e7f5e62c55fc7 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 01:38:10 -0800 Subject: [PATCH 2/5] feat: add Docker Compose local test infrastructure with seed data Add SQL seed data (PostgreSQL + MySQL) with ~1000 rows and deliberate diffs for showcasing data-diff. Default connection strings for all docker-compose databases, add profiles to keep lightweight default (PG + MySQL only), and add Makefile for developer ergonomics. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci.yml | 2 -- Makefile | 41 ++++++++++++++++++++++++ dev/seed/mysql/01_seed.sql | 59 +++++++++++++++++++++++++++++++++++ dev/seed/postgres/01_seed.sql | 47 ++++++++++++++++++++++++++++ docker-compose.yml | 6 ++++ tests/common.py | 14 +++++---- 6 files changed, 161 insertions(+), 8 deletions(-) create mode 100644 Makefile create mode 100644 dev/seed/mysql/01_seed.sql create mode 100644 dev/seed/postgres/01_seed.sql diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4746b0f4..700ccde1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,8 +45,6 @@ jobs: run: docker compose up -d --wait mysql postgres presto trino clickhouse - name: Run tests - env: - DATADIFF_CLICKHOUSE_URI: "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" run: | uv run pytest tests/ \ -o addopts="--timeout=300 --tb=short" \ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..eca715af --- /dev/null +++ b/Makefile @@ -0,0 +1,41 @@ +.PHONY: up up-full down test-unit test demo + +## Start PostgreSQL + MySQL (lightweight, fast startup) +up: + docker compose up -d --wait postgres mysql + +## Start all services including ClickHouse, Presto, Trino, Vertica +up-full: + docker compose --profile full up -d --wait + +## Stop all services and remove volumes +down: + docker compose --profile full down -v + +## Run unit tests (no database required) +test-unit: + uv run pytest tests/test_query.py tests/test_utils.py -x + +## Run full test suite against PG + MySQL (starts containers if needed) +test: up + uv run pytest tests/ \ + -o addopts="--timeout=300 --tb=short" \ + --ignore=tests/test_database_types.py \ + --ignore=tests/test_dbt_config_validators.py \ + --ignore=tests/test_main.py + +## Run data-diff against seed data to showcase diffing +demo: up + @echo "=== PostgreSQL: ratings_source vs ratings_target ===" + uv run python -m data_diff \ + postgresql://postgres:Password1@localhost/postgres \ + ratings_source ratings_target \ + --key-columns id \ + --columns rating + @echo "" + @echo "=== MySQL: ratings_source vs ratings_target ===" + uv run python -m data_diff \ + mysql://mysql:Password1@localhost/mysql \ + ratings_source ratings_target \ + --key-columns id \ + --columns rating diff --git a/dev/seed/mysql/01_seed.sql b/dev/seed/mysql/01_seed.sql new file mode 100644 index 00000000..c30e7045 --- /dev/null +++ b/dev/seed/mysql/01_seed.sql @@ -0,0 +1,59 @@ +-- Seed data for demonstrating data-diff capabilities. +-- Auto-executed by MySQL on first container startup. + +CREATE TABLE ratings_source ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + movie_id INT NOT NULL, + rating DECIMAL(2,1) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE ratings_target ( + id INT PRIMARY KEY, + user_id INT NOT NULL, + movie_id INT NOT NULL, + rating DECIMAL(2,1) NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Populate source with 1000 rows via stored procedure (MySQL lacks generate_series) +DELIMITER // +CREATE PROCEDURE seed_ratings() +BEGIN + DECLARE i INT DEFAULT 1; + WHILE i <= 1000 DO + INSERT INTO ratings_source (id, user_id, movie_id, rating, created_at) + VALUES ( + i, + 1 + (i % 200), + 1 + (i % 50), + 1 + (i % 5), + DATE_ADD('2025-01-01', INTERVAL i MINUTE) + ); + SET i = i + 1; + END WHILE; +END // +DELIMITER ; + +CALL seed_ratings(); +DROP PROCEDURE seed_ratings; + +-- Copy all rows into target +INSERT INTO ratings_target SELECT * FROM ratings_source; + +-- Introduce diffs: +-- 5 deleted rows (IDs 10-14 missing from target) +DELETE FROM ratings_target WHERE id BETWEEN 10 AND 14; + +-- 5 extra rows in target only (IDs 1001-1005) +INSERT INTO ratings_target (id, user_id, movie_id, rating, created_at) VALUES + (1001, 201, 51, 4.0, '2025-06-01 00:00:00'), + (1002, 202, 52, 3.0, '2025-06-02 00:00:00'), + (1003, 203, 53, 5.0, '2025-06-03 00:00:00'), + (1004, 204, 54, 2.0, '2025-06-04 00:00:00'), + (1005, 205, 55, 1.0, '2025-06-05 00:00:00'); + +-- 10 updated ratings (IDs 100-109 have different ratings in target) +UPDATE ratings_target SET rating = rating + 0.5 WHERE id BETWEEN 100 AND 109 AND rating < 5.0; +UPDATE ratings_target SET rating = 1.0 WHERE id BETWEEN 100 AND 109 AND rating >= 5.0; diff --git a/dev/seed/postgres/01_seed.sql b/dev/seed/postgres/01_seed.sql new file mode 100644 index 00000000..bb886eeb --- /dev/null +++ b/dev/seed/postgres/01_seed.sql @@ -0,0 +1,47 @@ +-- Seed data for demonstrating data-diff capabilities. +-- Auto-executed by PostgreSQL on first container startup. + +CREATE TABLE ratings_source ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + movie_id INTEGER NOT NULL, + rating NUMERIC(2,1) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT now() +); + +CREATE TABLE ratings_target ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + movie_id INTEGER NOT NULL, + rating NUMERIC(2,1) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT now() +); + +-- Populate source with 1000 rows +INSERT INTO ratings_source (id, user_id, movie_id, rating, created_at) +SELECT + g AS id, + 1 + (g % 200) AS user_id, + 1 + (g % 50) AS movie_id, + (1 + (g % 5))::NUMERIC(2,1) AS rating, + '2025-01-01'::TIMESTAMP + (g || ' minutes')::INTERVAL AS created_at +FROM generate_series(1, 1000) AS g; + +-- Copy all rows into target +INSERT INTO ratings_target SELECT * FROM ratings_source; + +-- Introduce diffs: +-- 5 deleted rows (IDs 10-14 missing from target) +DELETE FROM ratings_target WHERE id BETWEEN 10 AND 14; + +-- 5 extra rows in target only (IDs 1001-1005) +INSERT INTO ratings_target (id, user_id, movie_id, rating, created_at) VALUES + (1001, 201, 51, 4.0, '2025-06-01 00:00:00'), + (1002, 202, 52, 3.0, '2025-06-02 00:00:00'), + (1003, 203, 53, 5.0, '2025-06-03 00:00:00'), + (1004, 204, 54, 2.0, '2025-06-04 00:00:00'), + (1005, 205, 55, 1.0, '2025-06-05 00:00:00'); + +-- 10 updated ratings (IDs 100-109 have different ratings in target) +UPDATE ratings_target SET rating = rating + 0.5 WHERE id BETWEEN 100 AND 109 AND rating < 5.0; +UPDATE ratings_target SET rating = 1.0 WHERE id BETWEEN 100 AND 109 AND rating >= 5.0; diff --git a/docker-compose.yml b/docker-compose.yml index 26acc497..c69bd658 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,6 +12,7 @@ services: restart: always volumes: - postgresql-data:/var/lib/postgresql/data:delegated + - ./dev/seed/postgres:/docker-entrypoint-initdb.d:ro ports: - '5432:5432' expose: @@ -42,6 +43,7 @@ services: restart: always volumes: - mysql-data:/var/lib/mysql:delegated + - ./dev/seed/mysql:/docker-entrypoint-initdb.d:ro user: mysql ports: - '3306:3306' @@ -61,6 +63,7 @@ services: clickhouse: container_name: dd-clickhouse image: clickhouse/clickhouse-server:24.3 + profiles: [full] restart: always volumes: - clickhouse-data:/var/lib/clickhouse:delegated @@ -88,6 +91,7 @@ services: # prestodb.dbapi.connect(host="127.0.0.1", user="presto").cursor().execute('SELECT * FROM system.runtime.nodes') presto: + profiles: [full] container_name: dd-presto build: context: ./dev @@ -101,6 +105,7 @@ services: - local trino: + profiles: [full] container_name: dd-trino image: 'trinodb/trino:439' hostname: trino @@ -118,6 +123,7 @@ services: vertica: container_name: dd-vertica + profiles: [full] image: vertica/vertica-ce:24.1.0-0 restart: always volumes: diff --git a/tests/common.py b/tests/common.py index db9a4ba0..805e23dc 100644 --- a/tests/common.py +++ b/tests/common.py @@ -23,16 +23,18 @@ os.environ.get("DATADIFF_POSTGRESQL_URI") or "postgresql://postgres:Password1@localhost/postgres" ) TEST_SNOWFLAKE_CONN_STRING: str = os.environ.get("DATADIFF_SNOWFLAKE_URI") or None -TEST_PRESTO_CONN_STRING: str = os.environ.get("DATADIFF_PRESTO_URI") or None +TEST_PRESTO_CONN_STRING: str = os.environ.get("DATADIFF_PRESTO_URI") or "presto://test@localhost:8080/memory/default" TEST_BIGQUERY_CONN_STRING: str = os.environ.get("DATADIFF_BIGQUERY_URI") or None TEST_REDSHIFT_CONN_STRING: str = os.environ.get("DATADIFF_REDSHIFT_URI") or None TEST_ORACLE_CONN_STRING: str = None TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI") -TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or None -# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" -TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("DATADIFF_CLICKHOUSE_URI") -# vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica" -TEST_VERTICA_CONN_STRING: str = os.environ.get("DATADIFF_VERTICA_URI") +TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or "trino://test@localhost:8081/memory/default" +TEST_CLICKHOUSE_CONN_STRING: str = ( + os.environ.get("DATADIFF_CLICKHOUSE_URI") or "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" +) +TEST_VERTICA_CONN_STRING: str = ( + os.environ.get("DATADIFF_VERTICA_URI") or "vertica://vertica:Password1@localhost:5433/vertica" +) TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") From 37d1a5135aba3469c3342a3243ee5c0db1d1f183 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 01:49:36 -0800 Subject: [PATCH 3/5] fix: address code review findings from PR #24 Critical: - Fix _add_padding double-truncation regression for rounding branch (split into _truncate_and_pad and _zero_pad for correct behavior) Important: - Fix non-rounding timestamp path to use timestamptz cast for TimestampTZ - Add None guard to TableOp.type to avoid misleading errors - Use QueryBuilderError consistently for schema length mismatch - Revert Presto/Trino/Vertica conn defaults to None (CI doesn't test them) - Remove unused Presto/Trino from CI docker compose command - Add comprehensive tests for all timestamp paths and edge cases Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci.yml | 2 +- Makefile | 6 +++- data_diff/databases/postgresql.py | 19 ++++++++----- data_diff/queries/ast_classes.py | 4 ++- tests/common.py | 8 ++---- tests/test_query.py | 47 +++++++++++++++++++++++++++---- 6 files changed, 65 insertions(+), 21 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 700ccde1..74645ebf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,7 +42,7 @@ jobs: run: uv tool run ty check --python-version 3.10 - name: Build the stack - run: docker compose up -d --wait mysql postgres presto trino clickhouse + run: docker compose --profile full up -d --wait mysql postgres clickhouse - name: Run tests run: | diff --git a/Makefile b/Makefile index eca715af..1546d874 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,11 @@ down: test-unit: uv run pytest tests/test_query.py tests/test_utils.py -x -## Run full test suite against PG + MySQL (starts containers if needed) +## Run full test suite against PG + MySQL + ClickHouse (starts containers if needed) +## To also test Presto/Trino/Vertica, run `make up-full` first and set: +## export DATADIFF_PRESTO_URI="presto://test@localhost:8080/memory/default" +## export DATADIFF_TRINO_URI="trino://test@localhost:8081/memory/default" +## export DATADIFF_VERTICA_URI="vertica://vertica:Password1@localhost:5433/vertica" test: up uv run pytest tests/ \ -o addopts="--timeout=300 --tb=short" \ diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 25a624f6..46eeedd1 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -114,10 +114,14 @@ def md5_as_hex(self, s: str) -> str: return f"md5({s})" def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - def _add_padding(coltype: TemporalType, timestamp6: str): - return ( - f"RPAD(LEFT({timestamp6}, length({timestamp6}) - (6 - {coltype.precision})), length({timestamp6}), '0')" - ) + def _truncate_and_pad(coltype: TemporalType, timestamp6: str): + """Truncate a 6-digit-precision timestamp to target precision, then zero-pad back to 6 digits.""" + truncated = f"LEFT({timestamp6}, length({timestamp6}) - (6 - {coltype.precision}))" + return f"RPAD({truncated}, length({timestamp6}), '0')" + + def _zero_pad(coltype: TemporalType, already_truncated: str): + """Zero-pad an already-truncated timestamp back to 6 fractional digits.""" + return f"RPAD({already_truncated}, length({already_truncated}) + (6 - {coltype.precision}), '0')" try: is_date = coltype.is_date @@ -157,12 +161,13 @@ def _add_padding(coltype: TemporalType, timestamp6: str): f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US')) - (6-{coltype.precision}))" ) - padded = _add_padding(coltype, rounded_timestamp) + padded = _zero_pad(coltype, rounded_timestamp) return f"{null_case_begin} {padded} {null_case_end}" else: - rounded_timestamp = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - padded = _add_padding(coltype, rounded_timestamp) + ts_type = "timestamptz(6)" if isinstance(coltype, TimestampTZ) else "timestamp(6)" + rounded_timestamp = f"to_char({value}::{ts_type}, 'YYYY-mm-dd HH24:MI:SS.US')" + padded = _truncate_and_pad(coltype, rounded_timestamp) return padded def normalize_number(self, value: str, coltype: FractionalType) -> str: diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 03d53123..ddb0cb42 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -555,6 +555,8 @@ class TableOp(ExprNode, ITable, Root): def type(self): t1 = self.table1.type t2 = self.table2.type + if t1 is None or t2 is None: + return t1 or t2 if type(t1) is not type(t2): raise QueryBuilderError(f"Type mismatch in {self.op}: got {type(t1).__name__} and {type(t2).__name__}") return t1 @@ -564,7 +566,7 @@ def schema(self) -> Schema: s1 = self.table1.schema s2 = self.table2.schema if len(s1) != len(s2): - raise ValueError(f"TableOp requires tables with matching schema lengths, got {len(s1)} and {len(s2)}.") + raise QueryBuilderError(f"Schema length mismatch in {self.op}: got {len(s1)} and {len(s2)} columns") for (name1, type1), (name2, type2) in zip(s1.items(), s2.items()): if type(type1) is not type(type2): raise QueryBuilderError( diff --git a/tests/common.py b/tests/common.py index 805e23dc..33f9dac6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -23,18 +23,16 @@ os.environ.get("DATADIFF_POSTGRESQL_URI") or "postgresql://postgres:Password1@localhost/postgres" ) TEST_SNOWFLAKE_CONN_STRING: str = os.environ.get("DATADIFF_SNOWFLAKE_URI") or None -TEST_PRESTO_CONN_STRING: str = os.environ.get("DATADIFF_PRESTO_URI") or "presto://test@localhost:8080/memory/default" +TEST_PRESTO_CONN_STRING: str = os.environ.get("DATADIFF_PRESTO_URI") or None TEST_BIGQUERY_CONN_STRING: str = os.environ.get("DATADIFF_BIGQUERY_URI") or None TEST_REDSHIFT_CONN_STRING: str = os.environ.get("DATADIFF_REDSHIFT_URI") or None TEST_ORACLE_CONN_STRING: str = None TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI") -TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or "trino://test@localhost:8081/memory/default" +TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or None TEST_CLICKHOUSE_CONN_STRING: str = ( os.environ.get("DATADIFF_CLICKHOUSE_URI") or "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" ) -TEST_VERTICA_CONN_STRING: str = ( - os.environ.get("DATADIFF_VERTICA_URI") or "vertica://vertica:Password1@localhost:5433/vertica" -) +TEST_VERTICA_CONN_STRING: str = os.environ.get("DATADIFF_VERTICA_URI") or None TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") diff --git a/tests/test_query.py b/tests/test_query.py index 5d5f10b5..46fe7156 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -483,18 +483,35 @@ def test_minus_mismatched_types_raises(self): with self.assertRaises(QueryBuilderError): _ = op.schema - def test_type_property_validates_when_types_available(self): - # TableOp.type delegates to table1/table2.type; Select returns None - # so type validation is a pass-through for Select-wrapped tables. - # The real enforcement is in .schema validation above. + def test_type_property_with_none_types_passes(self): schema_a = CaseSensitiveDict({"x": Integer()}) schema_b = CaseSensitiveDict({"x": Integer()}) a = table("a", schema=schema_a) b = table("b", schema=schema_b) u = a.union(b) - # Both raw tables return None for .type, so no mismatch + # Select-wrapped tables return None for .type — should pass without error self.assertIsNone(u.type) + def test_schema_length_mismatch_raises_query_builder_error(self): + schema_a = CaseSensitiveDict({"x": Integer(), "y": Text()}) + schema_b = CaseSensitiveDict({"x": Integer()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.union(b) + with self.assertRaises(QueryBuilderError): + _ = op.schema + + def test_schema_mismatch_error_includes_details(self): + schema_a = CaseSensitiveDict({"col_a": Integer()}) + schema_b = CaseSensitiveDict({"col_b": Text()}) + a = table("a", schema=schema_a) + b = table("b", schema=schema_b) + op = a.union(b) + with self.assertRaises(QueryBuilderError) as ctx: + _ = op.schema + self.assertIn("col_a", str(ctx.exception)) + self.assertIn("UNION", str(ctx.exception)) + class TestPostgresqlTimestampNormalization(unittest.TestCase): def setUp(self): @@ -508,8 +525,26 @@ def test_timestamp_uses_timestamp_cast(self): def test_timestamptz_uses_timestamptz_cast(self): result = self.dialect.normalize_timestamp("col", TimestampTZ(precision=6, rounds=True)) self.assertIn("::timestamptz(6)", result) + self.assertNotIn("::timestamp(6)", result) - def test_padding_uses_length_based_calculation(self): + def test_rounding_padding_uses_zero_pad(self): result = self.dialect.normalize_timestamp("col", Timestamp(precision=3, rounds=True)) self.assertIn("length(", result) + # Rounding branch uses RPAD without LEFT (already truncated) + self.assertIn("RPAD(", result) + + def test_non_rounding_uses_truncate_and_pad(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=3, rounds=False)) + self.assertIn("length(", result) self.assertIn("RPAD(LEFT(", result) + self.assertNotIn("CASE WHEN", result) + + def test_non_rounding_timestamptz_uses_timestamptz_cast(self): + result = self.dialect.normalize_timestamp("col", TimestampTZ(precision=6, rounds=False)) + self.assertIn("::timestamptz(6)", result) + self.assertNotIn("::timestamp(6)", result) + + def test_precision_zero_rounding(self): + result = self.dialect.normalize_timestamp("col", Timestamp(precision=0, rounds=True)) + self.assertIn("RPAD(", result) + self.assertIn("(6 - 0)", result) From c3d3026f15e84a799ec1b0e93d9bac4ec9ed21b9 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 01:57:54 -0800 Subject: [PATCH 4/5] fix: address PR review findings for robustness and developer ergonomics - Revert ClickHouse default conn string to None so `make test` skips ClickHouse when the container isn't running; set URI explicitly in CI - Add None-schema guard in TableOp.schema with clear error message - Return None (not optimistic type) when one side of TableOp.type is unknown - Fix Makefile comment to accurately reflect PG + MySQL (not ClickHouse) - Add comment explaining why Join.schema skips cross-table type validation - Add tests for TableOp.type mismatch and matching branches Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci.yml | 2 ++ Makefile | 2 +- data_diff/queries/ast_classes.py | 5 ++++- tests/common.py | 4 +--- tests/test_query.py | 31 ++++++++++++++++++++++++++++++- 5 files changed, 38 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 74645ebf..b3ab63f3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,6 +45,8 @@ jobs: run: docker compose --profile full up -d --wait mysql postgres clickhouse - name: Run tests + env: + DATADIFF_CLICKHOUSE_URI: "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" run: | uv run pytest tests/ \ -o addopts="--timeout=300 --tb=short" \ diff --git a/Makefile b/Makefile index 1546d874..b7e636f8 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ down: test-unit: uv run pytest tests/test_query.py tests/test_utils.py -x -## Run full test suite against PG + MySQL + ClickHouse (starts containers if needed) +## Run full test suite against PG + MySQL (starts containers if needed) ## To also test Presto/Trino/Vertica, run `make up-full` first and set: ## export DATADIFF_PRESTO_URI="presto://test@localhost:8080/memory/default" ## export DATADIFF_TRINO_URI="trino://test@localhost:8081/memory/default" diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index ddb0cb42..a1d81992 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -483,6 +483,7 @@ class Join(ExprNode, ITable, Root): def schema(self) -> Schema: if not self.columns: raise ValueError("Join must specify columns explicitly (SELECT * not yet implemented).") + # No cross-table type validation needed: join combines columns from both tables rather than unioning rows s = self.source_tables[0].schema return type(s)({c.name: c.type for c in self.columns}) @@ -556,7 +557,7 @@ def type(self): t1 = self.table1.type t2 = self.table2.type if t1 is None or t2 is None: - return t1 or t2 + return None if type(t1) is not type(t2): raise QueryBuilderError(f"Type mismatch in {self.op}: got {type(t1).__name__} and {type(t2).__name__}") return t1 @@ -565,6 +566,8 @@ def type(self): def schema(self) -> Schema: s1 = self.table1.schema s2 = self.table2.schema + if s1 is None or s2 is None: + raise QueryBuilderError(f"Cannot validate {self.op}: one or both tables have no schema defined") if len(s1) != len(s2): raise QueryBuilderError(f"Schema length mismatch in {self.op}: got {len(s1)} and {len(s2)} columns") for (name1, type1), (name2, type2) in zip(s1.items(), s2.items()): diff --git a/tests/common.py b/tests/common.py index 33f9dac6..11c7f152 100644 --- a/tests/common.py +++ b/tests/common.py @@ -29,9 +29,7 @@ TEST_ORACLE_CONN_STRING: str = None TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI") TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or None -TEST_CLICKHOUSE_CONN_STRING: str = ( - os.environ.get("DATADIFF_CLICKHOUSE_URI") or "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" -) +TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("DATADIFF_CLICKHOUSE_URI") or None TEST_VERTICA_CONN_STRING: str = os.environ.get("DATADIFF_VERTICA_URI") or None TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") diff --git a/tests/test_query.py b/tests/test_query.py index 46fe7156..502cdd75 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -9,7 +9,7 @@ from data_diff.databases.base import BaseDialect, CompileError, Compiler, Database from data_diff.databases.postgresql import PostgresqlDialect from data_diff.queries.api import coalesce, code, cte, outerjoin, table, this, when -from data_diff.queries.ast_classes import QueryBuilderError, Random +from data_diff.queries.ast_classes import QueryBuilderError, Random, TableOp from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict @@ -512,6 +512,35 @@ def test_schema_mismatch_error_includes_details(self): self.assertIn("col_a", str(ctx.exception)) self.assertIn("UNION", str(ctx.exception)) + def test_type_mismatch_raises_query_builder_error(self): + """TableOp.type raises when both sides have non-None but different types.""" + + class FakeTable: + schema = None + + def __init__(self, type_val): + self.type = type_val + + op = TableOp("UNION", FakeTable(Integer()), FakeTable(Text())) + with self.assertRaises(QueryBuilderError) as ctx: + _ = op.type + self.assertIn("UNION", str(ctx.exception)) + self.assertIn("Integer", str(ctx.exception)) + self.assertIn("Text", str(ctx.exception)) + + def test_type_matching_returns_type(self): + """TableOp.type returns the type when both sides match.""" + + class FakeTable: + schema = None + + def __init__(self, type_val): + self.type = type_val + + t = Integer() + op = TableOp("UNION", FakeTable(t), FakeTable(Integer())) + self.assertIsInstance(op.type, Integer) + class TestPostgresqlTimestampNormalization(unittest.TestCase): def setUp(self): From 5cf4de3271fa696d92a922180c72a1de25bf9c55 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 02:05:06 -0800 Subject: [PATCH 5/5] fix: add CI comment for profile flag and normalize conn string defaults - Add comment explaining why --profile full is needed (ClickHouse is profile-gated; only explicitly named services start) - Add `or None` to Databricks and MsSQL conn strings to handle empty env vars consistently with all other optional databases Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci.yml | 1 + tests/common.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b3ab63f3..be233755 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,6 +42,7 @@ jobs: run: uv tool run ty check --python-version 3.10 - name: Build the stack + # --profile full unlocks profile-gated clickhouse; only named services start run: docker compose --profile full up -d --wait mysql postgres clickhouse - name: Run tests diff --git a/tests/common.py b/tests/common.py index 11c7f152..23717d15 100644 --- a/tests/common.py +++ b/tests/common.py @@ -27,12 +27,12 @@ TEST_BIGQUERY_CONN_STRING: str = os.environ.get("DATADIFF_BIGQUERY_URI") or None TEST_REDSHIFT_CONN_STRING: str = os.environ.get("DATADIFF_REDSHIFT_URI") or None TEST_ORACLE_CONN_STRING: str = None -TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI") +TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI") or None TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or None TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("DATADIFF_CLICKHOUSE_URI") or None TEST_VERTICA_CONN_STRING: str = os.environ.get("DATADIFF_VERTICA_URI") or None TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:" -TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") +TEST_MSSQL_CONN_STRING: str = os.environ.get("DATADIFF_MSSQL_URI") or None DEFAULT_N_SAMPLES = 50