// splitTableColumns splits the column list of a TABLE(...) return type on
// top-level commas.
//
// A comma is treated as a separator only when it is outside parentheses
// (so numeric(10, 2) stays in one piece) and outside a double-quoted
// identifier (so "a,b" stays in one piece). Inside a quoted identifier a
// doubled quote ("") is an escaped quote and does not end the identifier.
//
// The returned parts are raw substrings of inner: surrounding whitespace is
// preserved and empty segments are kept, so callers are expected to trim and
// filter. The result always contains at least one element; an empty input
// yields a single empty string.
func splitTableColumns(inner string) []string {
	var parts []string
	depth := 0
	inQuotes := false
	start := 0
	// Byte-wise iteration is safe here: every delimiter we look for is ASCII,
	// and ASCII bytes never occur inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(inner); i++ {
		ch := inner[i]
		if inQuotes {
			if ch == '"' {
				if i+1 < len(inner) && inner[i+1] == '"' {
					i++ // skip the second quote of an escaped "" pair
				} else {
					inQuotes = false
				}
			}
			continue
		}
		switch ch {
		case '"':
			inQuotes = true
		case '(':
			depth++
		case ')':
			// Clamp at zero: a stray ")" must not push depth negative,
			// otherwise every later top-level comma would be ignored.
			if depth > 0 {
				depth--
			}
		case ',':
			if depth == 0 {
				parts = append(parts, inner[start:i])
				start = i + 1
			}
		}
	}
	// Whatever follows the last separator (possibly empty) is the final part.
	parts = append(parts, inner[start:])
	return parts
}
// splitColumnNameAndType splits a TABLE column definition such as
// `"full name" public.mytype` into its column name and its type.
//
// The input is trimmed first. If the definition starts with a double quote,
// the name runs up to and including the matching closing quote; a doubled
// quote ("") inside the identifier is an escaped quote and does not close it.
// Otherwise the name ends at the first ASCII whitespace character (space,
// tab, newline, carriage return, vertical tab or form feed).
//
// The name is returned exactly as written (quotes included) and the type is
// the remainder with surrounding whitespace trimmed. When no type can be
// separated out (empty input, a single word, or an unterminated quoted
// identifier) the whole trimmed definition is returned as name and typePart
// is empty.
func splitColumnNameAndType(colDef string) (name, typePart string) {
	colDef = strings.TrimSpace(colDef)
	if colDef == "" {
		return "", ""
	}

	// nameEnd is the index one past the last byte of the name, or -1 when no
	// boundary between name and type was found.
	nameEnd := -1
	if colDef[0] == '"' {
		// Quoted identifier: scan for the closing quote, stepping over
		// escaped "" pairs.
		for i := 1; i < len(colDef); i++ {
			if colDef[i] != '"' {
				continue
			}
			if i+1 < len(colDef) && colDef[i+1] == '"' {
				i++ // skip the second quote of the escaped pair
				continue
			}
			nameEnd = i + 1
			break
		}
	} else {
		// Unquoted identifier: ends at the first whitespace of any common
		// kind, not just space/tab, so definitions broken across lines are
		// still split into name and type.
		nameEnd = strings.IndexAny(colDef, " \t\n\r\v\f")
	}

	if nameEnd == -1 {
		// Single word or unterminated quote: treat everything as the name.
		return colDef, ""
	}
	return colDef[:nameEnd], strings.TrimSpace(colDef[nameEnd:])
}
inner := returnType[6 : len(returnType)-1] // Remove "TABLE(" and ")" - // Split by comma to process each column definition - parts := strings.Split(inner, ",") + // Split by top-level commas (respecting nested parentheses like numeric(10,2)) + parts := splitTableColumns(inner) var normalizedParts []string for _, part := range parts { @@ -472,13 +552,11 @@ func normalizeFunctionReturnType(returnType string) string { continue } - // Normalize individual column definitions (name type) - fields := strings.Fields(part) - if len(fields) >= 2 { - // Normalize the type part - typePart := strings.Join(fields[1:], " ") + // Split column definition into name and type, respecting quoted identifiers + name, typePart := splitColumnNameAndType(part) + if typePart != "" { normalizedType := normalizePostgreSQLType(typePart) - normalizedParts = append(normalizedParts, fields[0]+" "+normalizedType) + normalizedParts = append(normalizedParts, name+" "+normalizedType) } else { // Just a type, normalize it normalizedParts = append(normalizedParts, normalizePostgreSQLType(part)) @@ -513,8 +591,26 @@ func stripSchemaFromReturnType(returnType, schema string) string { } // Handle TABLE(...) 
return types - strip schema from individual column types - if strings.HasPrefix(returnType, "TABLE(") { - return returnType // TABLE types are already handled by normalizeFunctionReturnType + if strings.HasPrefix(returnType, "TABLE(") && strings.HasSuffix(returnType, ")") { + inner := returnType[6 : len(returnType)-1] // Remove "TABLE(" and ")" + // Split by top-level commas (respecting nested parentheses like numeric(10,2)) + parts := splitTableColumns(inner) + var newParts []string + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + // Split column definition into name and type, respecting quoted identifiers + name, typePart := splitColumnNameAndType(part) + if typePart != "" { + strippedType := stripSchemaPrefix(typePart, prefix) + newParts = append(newParts, name+" "+strippedType) + } else { + newParts = append(newParts, part) + } + } + return "TABLE(" + strings.Join(newParts, ", ") + ")" } // Direct type name diff --git a/ir/normalize_test.go b/ir/normalize_test.go index cfaf96be..5eb31386 100644 --- a/ir/normalize_test.go +++ b/ir/normalize_test.go @@ -161,6 +161,166 @@ func TestNormalizeViewStripsSchemaPrefixFromDefinition(t *testing.T) { } } +func TestSplitColumnNameAndType(t *testing.T) { + tests := []struct { + name string + colDef string + expectedName string + expectedType string + }{ + {"simple", "id integer", "id", "integer"}, + {"schema qualified type", "col public.mytype", "col", "public.mytype"}, + {"quoted identifier", `"full name" text`, `"full name"`, "text"}, + {"quoted with schema type", `"my col" public.mytype`, `"my col"`, "public.mytype"}, + {"quoted with escaped quotes", `"it""s" integer`, `"it""s"`, "integer"}, + {"name only", "id", "id", ""}, + {"empty", "", "", ""}, + {"multi-word type", "col character varying", "col", "character varying"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name, typePart := splitColumnNameAndType(tt.colDef) + if name != 
tt.expectedName || typePart != tt.expectedType { + t.Errorf("splitColumnNameAndType(%q) = (%q, %q), want (%q, %q)", + tt.colDef, name, typePart, tt.expectedName, tt.expectedType) + } + }) + } +} + +func TestSplitTableColumns(t *testing.T) { + tests := []struct { + name string + inner string + expected []string + }{ + { + name: "simple columns", + inner: "id integer, name varchar", + expected: []string{"id integer", " name varchar"}, + }, + { + name: "numeric with precision and scale", + inner: "id integer, amount numeric(10, 2), name varchar", + expected: []string{"id integer", " amount numeric(10, 2)", " name varchar"}, + }, + { + name: "nested parentheses", + inner: "id integer, val numeric(10, 2), label character varying(100)", + expected: []string{"id integer", " val numeric(10, 2)", " label character varying(100)"}, + }, + { + name: "quoted identifier with comma", + inner: `"a,b" integer, name varchar`, + expected: []string{`"a,b" integer`, " name varchar"}, + }, + { + name: "quoted identifier with parenthesis", + inner: `"a(b)" integer, val numeric(10, 2)`, + expected: []string{`"a(b)" integer`, " val numeric(10, 2)"}, + }, + { + name: "single column", + inner: "id integer", + expected: []string{"id integer"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := splitTableColumns(tt.inner) + if len(result) != len(tt.expected) { + t.Fatalf("splitTableColumns(%q) returned %d parts, want %d: %v", tt.inner, len(result), len(tt.expected), result) + } + for i, part := range result { + if part != tt.expected[i] { + t.Errorf("splitTableColumns(%q)[%d] = %q, want %q", tt.inner, i, part, tt.expected[i]) + } + } + }) + } +} + +func TestStripSchemaFromReturnType(t *testing.T) { + tests := []struct { + name string + returnType string + schema string + expected string + }{ + { + name: "empty", + returnType: "", + schema: "public", + expected: "", + }, + { + name: "simple type no prefix", + returnType: "integer", + schema: "public", + 
expected: "integer", + }, + { + name: "simple type with prefix", + returnType: "public.mytype", + schema: "public", + expected: "mytype", + }, + { + name: "SETOF with prefix", + returnType: "SETOF public.actor", + schema: "public", + expected: "SETOF actor", + }, + { + name: "TABLE with custom type prefix", + returnType: "TABLE(id uuid, name varchar, created_at public.datetimeoffset)", + schema: "public", + expected: "TABLE(id uuid, name varchar, created_at datetimeoffset)", + }, + { + name: "TABLE with multiple custom type prefixes", + returnType: "TABLE(id uuid, created_at public.datetimeoffset, updated_at public.datetimeoffset)", + schema: "public", + expected: "TABLE(id uuid, created_at datetimeoffset, updated_at datetimeoffset)", + }, + { + name: "TABLE with no prefix to strip", + returnType: "TABLE(id uuid, name varchar)", + schema: "public", + expected: "TABLE(id uuid, name varchar)", + }, + { + name: "TABLE with numeric precision (commas in parens)", + returnType: "TABLE(id integer, amount numeric(10, 2), name public.mytype)", + schema: "public", + expected: "TABLE(id integer, amount numeric(10, 2), name mytype)", + }, + { + name: "array type with prefix", + returnType: "public.mytype[]", + schema: "public", + expected: "mytype[]", + }, + { + name: "TABLE with quoted column name", + returnType: `TABLE("full name" public.mytype, id uuid)`, + schema: "public", + expected: `TABLE("full name" mytype, id uuid)`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := stripSchemaFromReturnType(tt.returnType, tt.schema) + if result != tt.expected { + t.Errorf("stripSchemaFromReturnType(%q, %q) = %q, want %q", tt.returnType, tt.schema, result, tt.expected) + } + }) + } +} + func TestNormalizeCheckClause(t *testing.T) { tests := []struct { name string diff --git a/testdata/diff/create_function/issue_360_returns_table_custom_type/diff.sql b/testdata/diff/create_function/issue_360_returns_table_custom_type/diff.sql new file mode 
100644 index 00000000..e69de29b diff --git a/testdata/diff/create_function/issue_360_returns_table_custom_type/new.sql b/testdata/diff/create_function/issue_360_returns_table_custom_type/new.sql new file mode 100644 index 00000000..e75034f8 --- /dev/null +++ b/testdata/diff/create_function/issue_360_returns_table_custom_type/new.sql @@ -0,0 +1,30 @@ +CREATE TYPE datetimeoffset AS (local_time timestamp without time zone, offset_minutes smallint); + +CREATE TABLE account_groups ( + id uuid NOT NULL, + company_id uuid NOT NULL, + name varchar NOT NULL, + created_at datetimeoffset NOT NULL, + updated_at datetimeoffset NOT NULL +); + +CREATE OR REPLACE FUNCTION get_account_group_by_id( + p_group_id uuid +) +RETURNS TABLE(id uuid, company_id uuid, name varchar, created_at datetimeoffset, updated_at datetimeoffset) +LANGUAGE plpgsql +VOLATILE +SECURITY DEFINER +AS $$ +BEGIN + RETURN QUERY + SELECT + ag.id, + ag.company_id, + ag.name, + ag.created_at, + ag.updated_at + FROM account_groups ag + WHERE ag.id = p_group_id; +END; +$$; diff --git a/testdata/diff/create_function/issue_360_returns_table_custom_type/old.sql b/testdata/diff/create_function/issue_360_returns_table_custom_type/old.sql new file mode 100644 index 00000000..e75034f8 --- /dev/null +++ b/testdata/diff/create_function/issue_360_returns_table_custom_type/old.sql @@ -0,0 +1,30 @@ +CREATE TYPE datetimeoffset AS (local_time timestamp without time zone, offset_minutes smallint); + +CREATE TABLE account_groups ( + id uuid NOT NULL, + company_id uuid NOT NULL, + name varchar NOT NULL, + created_at datetimeoffset NOT NULL, + updated_at datetimeoffset NOT NULL +); + +CREATE OR REPLACE FUNCTION get_account_group_by_id( + p_group_id uuid +) +RETURNS TABLE(id uuid, company_id uuid, name varchar, created_at datetimeoffset, updated_at datetimeoffset) +LANGUAGE plpgsql +VOLATILE +SECURITY DEFINER +AS $$ +BEGIN + RETURN QUERY + SELECT + ag.id, + ag.company_id, + ag.name, + ag.created_at, + ag.updated_at + FROM 
account_groups ag + WHERE ag.id = p_group_id; +END; +$$; diff --git a/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.json b/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.json new file mode 100644 index 00000000..0ae1da87 --- /dev/null +++ b/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.json @@ -0,0 +1,9 @@ +{ + "version": "1.0.0", + "pgschema_version": "1.7.4", + "created_at": "1970-01-01T00:00:00Z", + "source_fingerprint": { + "hash": "bc4fc478f2d7ae4cc204de3447d992dface8f485a9227504fed99b21817cb888" + }, + "groups": null +} diff --git a/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.sql b/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.sql new file mode 100644 index 00000000..e69de29b diff --git a/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.txt b/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.txt new file mode 100644 index 00000000..241994af --- /dev/null +++ b/testdata/diff/create_function/issue_360_returns_table_custom_type/plan.txt @@ -0,0 +1 @@ +No changes detected.