Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 106 additions & 10 deletions ir/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,86 @@ func normalizeProcedure(procedure *Procedure) {
// Same rationale as functions — see normalizeFunction. (Issue #354)
}

// splitTableColumns breaks the inside of a TABLE(...) column list into one
// string per column definition.
//
// Only commas at parenthesis depth zero and outside double-quoted identifiers
// act as separators, so type modifiers such as numeric(10, 2) and names such
// as "a,b" stay in one piece. A doubled quote ("") inside a quoted identifier
// is treated as an escaped quote rather than the end of the identifier.
//
// Segments are returned exactly as they appear in inner (no trimming), and the
// result always holds at least one element, even for an empty input.
func splitTableColumns(inner string) []string {
	var (
		columns      []string
		nesting      int
		quoted       bool
		segmentStart int
	)

	for pos := 0; pos < len(inner); pos++ {
		c := inner[pos]
		switch {
		case quoted && c == '"':
			if pos+1 < len(inner) && inner[pos+1] == '"' {
				pos++ // step over the second half of an escaped ""
			} else {
				quoted = false
			}
		case quoted:
			// Any other byte inside a quoted identifier is plain text.
		case c == '"':
			quoted = true
		case c == '(':
			nesting++
		case c == ')':
			nesting--
		case c == ',' && nesting == 0:
			columns = append(columns, inner[segmentStart:pos])
			segmentStart = pos + 1
		}
	}

	// Whatever follows the last separator is the final column.
	return append(columns, inner[segmentStart:])
}

// splitColumnNameAndType splits a single TABLE column definition such as
// `"full name" public.mytype` into its column name and its type.
//
// The name is either a double-quoted identifier (returned with its quotes; an
// embedded quote is written as "") or an unquoted identifier that ends at the
// first whitespace character. Everything after the name, with surrounding
// whitespace trimmed, is returned as typePart.
//
// Both results are empty for blank input. When nothing follows the name, or a
// quoted identifier is never closed, the trimmed input is returned as name and
// typePart is empty.
func splitColumnNameAndType(colDef string) (name, typePart string) {
	colDef = strings.TrimSpace(colDef)
	if colDef == "" {
		return "", ""
	}

	if colDef[0] != '"' {
		// Unquoted identifier: stop at the first whitespace of any kind, not
		// only space or tab, so a definition wrapped across lines
		// ("id\n  integer") still separates into name and type.
		nameEnd := strings.IndexAny(colDef, " \t\n\r\v\f")
		if nameEnd == -1 {
			return colDef, ""
		}
		return colDef[:nameEnd], strings.TrimSpace(colDef[nameEnd:])
	}

	// Quoted identifier: scan for the closing double quote. PostgreSQL escapes
	// an embedded quote by doubling it, so a "" pair does not end the name.
	for i := 1; i < len(colDef); i++ {
		if colDef[i] != '"' {
			continue
		}
		if i+1 < len(colDef) && colDef[i+1] == '"' {
			i++ // skip the second quote of the escaped pair
			continue
		}
		return colDef[:i+1], strings.TrimSpace(colDef[i+1:])
	}

	// Unterminated quote: treat the whole definition as the name.
	return colDef, ""
}

// normalizeFunctionReturnType normalizes function return types, especially TABLE types
func normalizeFunctionReturnType(returnType string) string {
if returnType == "" {
Expand All @@ -462,8 +542,8 @@ func normalizeFunctionReturnType(returnType string) string {
// Extract the contents inside TABLE(...)
inner := returnType[6 : len(returnType)-1] // Remove "TABLE(" and ")"

// Split by comma to process each column definition
parts := strings.Split(inner, ",")
// Split by top-level commas (respecting nested parentheses like numeric(10,2))
parts := splitTableColumns(inner)
var normalizedParts []string

for _, part := range parts {
Expand All @@ -472,13 +552,11 @@ func normalizeFunctionReturnType(returnType string) string {
continue
}

// Normalize individual column definitions (name type)
fields := strings.Fields(part)
if len(fields) >= 2 {
// Normalize the type part
typePart := strings.Join(fields[1:], " ")
// Split column definition into name and type, respecting quoted identifiers
name, typePart := splitColumnNameAndType(part)
if typePart != "" {
normalizedType := normalizePostgreSQLType(typePart)
normalizedParts = append(normalizedParts, fields[0]+" "+normalizedType)
normalizedParts = append(normalizedParts, name+" "+normalizedType)
} else {
// Just a type, normalize it
normalizedParts = append(normalizedParts, normalizePostgreSQLType(part))
Expand Down Expand Up @@ -513,8 +591,26 @@ func stripSchemaFromReturnType(returnType, schema string) string {
}

// Handle TABLE(...) return types - strip schema from individual column types
if strings.HasPrefix(returnType, "TABLE(") {
return returnType // TABLE types are already handled by normalizeFunctionReturnType
if strings.HasPrefix(returnType, "TABLE(") && strings.HasSuffix(returnType, ")") {
inner := returnType[6 : len(returnType)-1] // Remove "TABLE(" and ")"
// Split by top-level commas (respecting nested parentheses like numeric(10,2))
parts := splitTableColumns(inner)
var newParts []string
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Split column definition into name and type, respecting quoted identifiers
name, typePart := splitColumnNameAndType(part)
if typePart != "" {
strippedType := stripSchemaPrefix(typePart, prefix)
newParts = append(newParts, name+" "+strippedType)
} else {
newParts = append(newParts, part)
}
}
return "TABLE(" + strings.Join(newParts, ", ") + ")"
}

// Direct type name
Expand Down
160 changes: 160 additions & 0 deletions ir/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,166 @@ func TestNormalizeViewStripsSchemaPrefixFromDefinition(t *testing.T) {
}
}

// TestSplitColumnNameAndType runs splitColumnNameAndType over plain, quoted,
// schema-qualified, name-only and empty column definitions and checks both
// returned halves against the expected name and type.
func TestSplitColumnNameAndType(t *testing.T) {
	type scenario struct {
		name         string
		colDef       string
		expectedName string
		expectedType string
	}

	scenarios := []scenario{
		{name: "simple", colDef: "id integer", expectedName: "id", expectedType: "integer"},
		{name: "schema qualified type", colDef: "col public.mytype", expectedName: "col", expectedType: "public.mytype"},
		{name: "quoted identifier", colDef: `"full name" text`, expectedName: `"full name"`, expectedType: "text"},
		{name: "quoted with schema type", colDef: `"my col" public.mytype`, expectedName: `"my col"`, expectedType: "public.mytype"},
		{name: "quoted with escaped quotes", colDef: `"it""s" integer`, expectedName: `"it""s"`, expectedType: "integer"},
		{name: "name only", colDef: "id", expectedName: "id", expectedType: ""},
		{name: "empty", colDef: "", expectedName: "", expectedType: ""},
		{name: "multi-word type", colDef: "col character varying", expectedName: "col", expectedType: "character varying"},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			gotName, gotType := splitColumnNameAndType(sc.colDef)
			if gotName == sc.expectedName && gotType == sc.expectedType {
				return
			}
			t.Errorf("splitColumnNameAndType(%q) = (%q, %q), want (%q, %q)",
				sc.colDef, gotName, gotType, sc.expectedName, sc.expectedType)
		})
	}
}

// TestSplitTableColumns checks that splitTableColumns separates a TABLE column
// list only at top-level commas. Commas inside parentheses or inside
// double-quoted identifiers must not split, and the expected segments keep the
// leading space that follows each separating comma.
func TestSplitTableColumns(t *testing.T) {
	type scenario struct {
		name     string
		inner    string
		expected []string
	}

	scenarios := []scenario{
		{
			"simple columns",
			"id integer, name varchar",
			[]string{"id integer", " name varchar"},
		},
		{
			"numeric with precision and scale",
			"id integer, amount numeric(10, 2), name varchar",
			[]string{"id integer", " amount numeric(10, 2)", " name varchar"},
		},
		{
			"nested parentheses",
			"id integer, val numeric(10, 2), label character varying(100)",
			[]string{"id integer", " val numeric(10, 2)", " label character varying(100)"},
		},
		{
			"quoted identifier with comma",
			`"a,b" integer, name varchar`,
			[]string{`"a,b" integer`, " name varchar"},
		},
		{
			"quoted identifier with parenthesis",
			`"a(b)" integer, val numeric(10, 2)`,
			[]string{`"a(b)" integer`, " val numeric(10, 2)"},
		},
		{
			"single column",
			"id integer",
			[]string{"id integer"},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got := splitTableColumns(sc.inner)
			// A length mismatch aborts the subtest so the element loop below
			// never indexes past the end of the expected slice.
			if want := len(sc.expected); len(got) != want {
				t.Fatalf("splitTableColumns(%q) returned %d parts, want %d: %v", sc.inner, len(got), want, got)
			}
			for idx := range got {
				if got[idx] == sc.expected[idx] {
					continue
				}
				t.Errorf("splitTableColumns(%q)[%d] = %q, want %q", sc.inner, idx, got[idx], sc.expected[idx])
			}
		})
	}
}

// TestStripSchemaFromReturnType checks stripSchemaFromReturnType against empty,
// plain, SETOF, array and TABLE(...) return types, expecting the given schema
// prefix to be removed from type names while column names (including quoted
// ones) and type modifiers such as numeric(10, 2) are left intact.
func TestStripSchemaFromReturnType(t *testing.T) {
	type scenario struct {
		name       string
		returnType string
		schema     string
		expected   string
	}

	scenarios := []scenario{
		{"empty", "", "public", ""},
		{"simple type no prefix", "integer", "public", "integer"},
		{"simple type with prefix", "public.mytype", "public", "mytype"},
		{"SETOF with prefix", "SETOF public.actor", "public", "SETOF actor"},
		{
			"TABLE with custom type prefix",
			"TABLE(id uuid, name varchar, created_at public.datetimeoffset)",
			"public",
			"TABLE(id uuid, name varchar, created_at datetimeoffset)",
		},
		{
			"TABLE with multiple custom type prefixes",
			"TABLE(id uuid, created_at public.datetimeoffset, updated_at public.datetimeoffset)",
			"public",
			"TABLE(id uuid, created_at datetimeoffset, updated_at datetimeoffset)",
		},
		{
			"TABLE with no prefix to strip",
			"TABLE(id uuid, name varchar)",
			"public",
			"TABLE(id uuid, name varchar)",
		},
		{
			"TABLE with numeric precision (commas in parens)",
			"TABLE(id integer, amount numeric(10, 2), name public.mytype)",
			"public",
			"TABLE(id integer, amount numeric(10, 2), name mytype)",
		},
		{"array type with prefix", "public.mytype[]", "public", "mytype[]"},
		{
			"TABLE with quoted column name",
			`TABLE("full name" public.mytype, id uuid)`,
			"public",
			`TABLE("full name" mytype, id uuid)`,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got := stripSchemaFromReturnType(sc.returnType, sc.schema)
			if got == sc.expected {
				return
			}
			t.Errorf("stripSchemaFromReturnType(%q, %q) = %q, want %q", sc.returnType, sc.schema, got, sc.expected)
		})
	}
}

func TestNormalizeCheckClause(t *testing.T) {
tests := []struct {
name string
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-- Composite type with two fields: local_time (timestamp without time zone)
-- and offset_minutes (smallint).
CREATE TYPE datetimeoffset AS (local_time timestamp without time zone, offset_minutes smallint);

-- Table whose created_at and updated_at columns use the datetimeoffset
-- composite type declared above.
CREATE TABLE account_groups (
id uuid NOT NULL,
company_id uuid NOT NULL,
name varchar NOT NULL,
created_at datetimeoffset NOT NULL,
updated_at datetimeoffset NOT NULL
);

-- Returns the account_groups rows whose id equals p_group_id.
-- The RETURNS TABLE column list mixes built-in types with the user-defined
-- datetimeoffset type, written here without a schema qualifier.
CREATE OR REPLACE FUNCTION get_account_group_by_id(
p_group_id uuid
)
RETURNS TABLE(id uuid, company_id uuid, name varchar, created_at datetimeoffset, updated_at datetimeoffset)
LANGUAGE plpgsql
VOLATILE
SECURITY DEFINER
AS $$
BEGIN
RETURN QUERY
SELECT
ag.id,
ag.company_id,
ag.name,
ag.created_at,
ag.updated_at
FROM account_groups ag
WHERE ag.id = p_group_id;
END;
$$;
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-- Composite type with two fields: local_time (timestamp without time zone)
-- and offset_minutes (smallint).
CREATE TYPE datetimeoffset AS (local_time timestamp without time zone, offset_minutes smallint);

-- Table whose created_at and updated_at columns use the datetimeoffset
-- composite type declared above.
CREATE TABLE account_groups (
id uuid NOT NULL,
company_id uuid NOT NULL,
name varchar NOT NULL,
created_at datetimeoffset NOT NULL,
updated_at datetimeoffset NOT NULL
);

-- Returns the account_groups rows whose id equals p_group_id.
-- The RETURNS TABLE column list mixes built-in types with the user-defined
-- datetimeoffset type, written here without a schema qualifier.
CREATE OR REPLACE FUNCTION get_account_group_by_id(
p_group_id uuid
)
RETURNS TABLE(id uuid, company_id uuid, name varchar, created_at datetimeoffset, updated_at datetimeoffset)
LANGUAGE plpgsql
VOLATILE
SECURITY DEFINER
AS $$
BEGIN
RETURN QUERY
SELECT
ag.id,
ag.company_id,
ag.name,
ag.created_at,
ag.updated_at
FROM account_groups ag
WHERE ag.id = p_group_id;
END;
$$;
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"version": "1.0.0",
"pgschema_version": "1.7.4",
"created_at": "1970-01-01T00:00:00Z",
"source_fingerprint": {
"hash": "bc4fc478f2d7ae4cc204de3447d992dface8f485a9227504fed99b21817cb888"
},
"groups": null
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
No changes detected.
Loading