Skip to content

Commit e4464b8

Browse files
tianzhouclaude
andauthored
fix: strip schema qualifiers from custom types in RETURNS TABLE (#360) (#361)
* fix: strip schema qualifiers from custom types in RETURNS TABLE (#360) When a function uses RETURNS TABLE with custom types (domains, composite types, enums), pg_get_function_result may schema-qualify those types depending on search_path. The normalization in stripSchemaFromReturnType was skipping TABLE return types entirely, causing a false diff between desired and current state when planning. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: use parenthesis-aware splitting for TABLE column types Address review feedback: strings.Split(inner, ",") breaks for types like numeric(10, 2) where commas appear inside parentheses. Extract a splitTableColumns helper that tracks parenthesis depth, and use it in both normalizeFunctionReturnType and stripSchemaFromReturnType. Also add unit tests for splitTableColumns and stripSchemaFromReturnType to directly verify the schema-stripping logic. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: handle quoted column identifiers in TABLE return types Address Copilot review: strings.Fields breaks for quoted identifiers with spaces like "full name" in RETURNS TABLE. Extract a splitColumnNameAndType helper that respects double-quoted identifiers, and use it in both normalizeFunctionReturnType and stripSchemaFromReturnType. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: respect quoted identifiers in splitTableColumns Address review: splitTableColumns now tracks double-quoted state so commas and parentheses inside quoted identifiers (e.g., "a,b") are not treated as delimiters. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 53b055b commit e4464b8

8 files changed

Lines changed: 336 additions & 10 deletions

File tree

ir/normalize.go

Lines changed: 106 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,86 @@ func normalizeProcedure(procedure *Procedure) {
451451
// Same rationale as functions — see normalizeFunction. (Issue #354)
452452
}
453453

454+
// splitTableColumns splits a TABLE column list by top-level commas,
455+
// respecting nested parentheses (e.g., numeric(10, 2)).
456+
func splitTableColumns(inner string) []string {
457+
var parts []string
458+
depth := 0
459+
inQuotes := false
460+
start := 0
461+
for i := 0; i < len(inner); i++ {
462+
ch := inner[i]
463+
if inQuotes {
464+
if ch == '"' {
465+
if i+1 < len(inner) && inner[i+1] == '"' {
466+
i++ // skip escaped ""
467+
} else {
468+
inQuotes = false
469+
}
470+
}
471+
continue
472+
}
473+
switch ch {
474+
case '"':
475+
inQuotes = true
476+
case '(':
477+
depth++
478+
case ')':
479+
depth--
480+
case ',':
481+
if depth == 0 {
482+
parts = append(parts, inner[start:i])
483+
start = i + 1
484+
}
485+
}
486+
}
487+
parts = append(parts, inner[start:])
488+
return parts
489+
}
490+
491+
// splitColumnNameAndType splits a TABLE column definition like `"full name" public.mytype`
492+
// into the column name and the type, respecting double-quoted identifiers.
493+
func splitColumnNameAndType(colDef string) (name, typePart string) {
494+
colDef = strings.TrimSpace(colDef)
495+
if colDef == "" {
496+
return "", ""
497+
}
498+
499+
var nameEnd int
500+
if colDef[0] == '"' {
501+
// Quoted identifier: find the closing double-quote
502+
// PostgreSQL escapes embedded quotes as ""
503+
i := 1
504+
for i < len(colDef) {
505+
if colDef[i] == '"' {
506+
if i+1 < len(colDef) && colDef[i+1] == '"' {
507+
i += 2 // skip escaped ""
508+
continue
509+
}
510+
nameEnd = i + 1
511+
break
512+
}
513+
i++
514+
}
515+
if nameEnd == 0 {
516+
// Unterminated quote — treat whole thing as name
517+
return colDef, ""
518+
}
519+
} else {
520+
// Unquoted identifier: ends at first whitespace
521+
nameEnd = strings.IndexFunc(colDef, func(r rune) bool {
522+
return r == ' ' || r == '\t'
523+
})
524+
if nameEnd == -1 {
525+
return colDef, ""
526+
}
527+
}
528+
529+
name = colDef[:nameEnd]
530+
rest := strings.TrimSpace(colDef[nameEnd:])
531+
return name, rest
532+
}
533+
454534
// normalizeFunctionReturnType normalizes function return types, especially TABLE types
455535
func normalizeFunctionReturnType(returnType string) string {
456536
if returnType == "" {
@@ -462,8 +542,8 @@ func normalizeFunctionReturnType(returnType string) string {
462542
// Extract the contents inside TABLE(...)
463543
inner := returnType[6 : len(returnType)-1] // Remove "TABLE(" and ")"
464544

465-
// Split by comma to process each column definition
466-
parts := strings.Split(inner, ",")
545+
// Split by top-level commas (respecting nested parentheses like numeric(10,2))
546+
parts := splitTableColumns(inner)
467547
var normalizedParts []string
468548

469549
for _, part := range parts {
@@ -472,13 +552,11 @@ func normalizeFunctionReturnType(returnType string) string {
472552
continue
473553
}
474554

475-
// Normalize individual column definitions (name type)
476-
fields := strings.Fields(part)
477-
if len(fields) >= 2 {
478-
// Normalize the type part
479-
typePart := strings.Join(fields[1:], " ")
555+
// Split column definition into name and type, respecting quoted identifiers
556+
name, typePart := splitColumnNameAndType(part)
557+
if typePart != "" {
480558
normalizedType := normalizePostgreSQLType(typePart)
481-
normalizedParts = append(normalizedParts, fields[0]+" "+normalizedType)
559+
normalizedParts = append(normalizedParts, name+" "+normalizedType)
482560
} else {
483561
// Just a type, normalize it
484562
normalizedParts = append(normalizedParts, normalizePostgreSQLType(part))
@@ -513,8 +591,26 @@ func stripSchemaFromReturnType(returnType, schema string) string {
513591
}
514592

515593
// Handle TABLE(...) return types - strip schema from individual column types
516-
if strings.HasPrefix(returnType, "TABLE(") {
517-
return returnType // TABLE types are already handled by normalizeFunctionReturnType
594+
if strings.HasPrefix(returnType, "TABLE(") && strings.HasSuffix(returnType, ")") {
595+
inner := returnType[6 : len(returnType)-1] // Remove "TABLE(" and ")"
596+
// Split by top-level commas (respecting nested parentheses like numeric(10,2))
597+
parts := splitTableColumns(inner)
598+
var newParts []string
599+
for _, part := range parts {
600+
part = strings.TrimSpace(part)
601+
if part == "" {
602+
continue
603+
}
604+
// Split column definition into name and type, respecting quoted identifiers
605+
name, typePart := splitColumnNameAndType(part)
606+
if typePart != "" {
607+
strippedType := stripSchemaPrefix(typePart, prefix)
608+
newParts = append(newParts, name+" "+strippedType)
609+
} else {
610+
newParts = append(newParts, part)
611+
}
612+
}
613+
return "TABLE(" + strings.Join(newParts, ", ") + ")"
518614
}
519615

520616
// Direct type name

ir/normalize_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,166 @@ func TestNormalizeViewStripsSchemaPrefixFromDefinition(t *testing.T) {
161161
}
162162
}
163163

164+
func TestSplitColumnNameAndType(t *testing.T) {
165+
tests := []struct {
166+
name string
167+
colDef string
168+
expectedName string
169+
expectedType string
170+
}{
171+
{"simple", "id integer", "id", "integer"},
172+
{"schema qualified type", "col public.mytype", "col", "public.mytype"},
173+
{"quoted identifier", `"full name" text`, `"full name"`, "text"},
174+
{"quoted with schema type", `"my col" public.mytype`, `"my col"`, "public.mytype"},
175+
{"quoted with escaped quotes", `"it""s" integer`, `"it""s"`, "integer"},
176+
{"name only", "id", "id", ""},
177+
{"empty", "", "", ""},
178+
{"multi-word type", "col character varying", "col", "character varying"},
179+
}
180+
181+
for _, tt := range tests {
182+
t.Run(tt.name, func(t *testing.T) {
183+
name, typePart := splitColumnNameAndType(tt.colDef)
184+
if name != tt.expectedName || typePart != tt.expectedType {
185+
t.Errorf("splitColumnNameAndType(%q) = (%q, %q), want (%q, %q)",
186+
tt.colDef, name, typePart, tt.expectedName, tt.expectedType)
187+
}
188+
})
189+
}
190+
}
191+
192+
func TestSplitTableColumns(t *testing.T) {
193+
tests := []struct {
194+
name string
195+
inner string
196+
expected []string
197+
}{
198+
{
199+
name: "simple columns",
200+
inner: "id integer, name varchar",
201+
expected: []string{"id integer", " name varchar"},
202+
},
203+
{
204+
name: "numeric with precision and scale",
205+
inner: "id integer, amount numeric(10, 2), name varchar",
206+
expected: []string{"id integer", " amount numeric(10, 2)", " name varchar"},
207+
},
208+
{
209+
name: "nested parentheses",
210+
inner: "id integer, val numeric(10, 2), label character varying(100)",
211+
expected: []string{"id integer", " val numeric(10, 2)", " label character varying(100)"},
212+
},
213+
{
214+
name: "quoted identifier with comma",
215+
inner: `"a,b" integer, name varchar`,
216+
expected: []string{`"a,b" integer`, " name varchar"},
217+
},
218+
{
219+
name: "quoted identifier with parenthesis",
220+
inner: `"a(b)" integer, val numeric(10, 2)`,
221+
expected: []string{`"a(b)" integer`, " val numeric(10, 2)"},
222+
},
223+
{
224+
name: "single column",
225+
inner: "id integer",
226+
expected: []string{"id integer"},
227+
},
228+
}
229+
230+
for _, tt := range tests {
231+
t.Run(tt.name, func(t *testing.T) {
232+
result := splitTableColumns(tt.inner)
233+
if len(result) != len(tt.expected) {
234+
t.Fatalf("splitTableColumns(%q) returned %d parts, want %d: %v", tt.inner, len(result), len(tt.expected), result)
235+
}
236+
for i, part := range result {
237+
if part != tt.expected[i] {
238+
t.Errorf("splitTableColumns(%q)[%d] = %q, want %q", tt.inner, i, part, tt.expected[i])
239+
}
240+
}
241+
})
242+
}
243+
}
244+
245+
func TestStripSchemaFromReturnType(t *testing.T) {
246+
tests := []struct {
247+
name string
248+
returnType string
249+
schema string
250+
expected string
251+
}{
252+
{
253+
name: "empty",
254+
returnType: "",
255+
schema: "public",
256+
expected: "",
257+
},
258+
{
259+
name: "simple type no prefix",
260+
returnType: "integer",
261+
schema: "public",
262+
expected: "integer",
263+
},
264+
{
265+
name: "simple type with prefix",
266+
returnType: "public.mytype",
267+
schema: "public",
268+
expected: "mytype",
269+
},
270+
{
271+
name: "SETOF with prefix",
272+
returnType: "SETOF public.actor",
273+
schema: "public",
274+
expected: "SETOF actor",
275+
},
276+
{
277+
name: "TABLE with custom type prefix",
278+
returnType: "TABLE(id uuid, name varchar, created_at public.datetimeoffset)",
279+
schema: "public",
280+
expected: "TABLE(id uuid, name varchar, created_at datetimeoffset)",
281+
},
282+
{
283+
name: "TABLE with multiple custom type prefixes",
284+
returnType: "TABLE(id uuid, created_at public.datetimeoffset, updated_at public.datetimeoffset)",
285+
schema: "public",
286+
expected: "TABLE(id uuid, created_at datetimeoffset, updated_at datetimeoffset)",
287+
},
288+
{
289+
name: "TABLE with no prefix to strip",
290+
returnType: "TABLE(id uuid, name varchar)",
291+
schema: "public",
292+
expected: "TABLE(id uuid, name varchar)",
293+
},
294+
{
295+
name: "TABLE with numeric precision (commas in parens)",
296+
returnType: "TABLE(id integer, amount numeric(10, 2), name public.mytype)",
297+
schema: "public",
298+
expected: "TABLE(id integer, amount numeric(10, 2), name mytype)",
299+
},
300+
{
301+
name: "array type with prefix",
302+
returnType: "public.mytype[]",
303+
schema: "public",
304+
expected: "mytype[]",
305+
},
306+
{
307+
name: "TABLE with quoted column name",
308+
returnType: `TABLE("full name" public.mytype, id uuid)`,
309+
schema: "public",
310+
expected: `TABLE("full name" mytype, id uuid)`,
311+
},
312+
}
313+
314+
for _, tt := range tests {
315+
t.Run(tt.name, func(t *testing.T) {
316+
result := stripSchemaFromReturnType(tt.returnType, tt.schema)
317+
if result != tt.expected {
318+
t.Errorf("stripSchemaFromReturnType(%q, %q) = %q, want %q", tt.returnType, tt.schema, result, tt.expected)
319+
}
320+
})
321+
}
322+
}
323+
164324
func TestNormalizeCheckClause(t *testing.T) {
165325
tests := []struct {
166326
name string

testdata/diff/create_function/issue_360_returns_table_custom_type/diff.sql

Whitespace-only changes.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
CREATE TYPE datetimeoffset AS (local_time timestamp without time zone, offset_minutes smallint);
2+
3+
CREATE TABLE account_groups (
4+
id uuid NOT NULL,
5+
company_id uuid NOT NULL,
6+
name varchar NOT NULL,
7+
created_at datetimeoffset NOT NULL,
8+
updated_at datetimeoffset NOT NULL
9+
);
10+
11+
CREATE OR REPLACE FUNCTION get_account_group_by_id(
12+
p_group_id uuid
13+
)
14+
RETURNS TABLE(id uuid, company_id uuid, name varchar, created_at datetimeoffset, updated_at datetimeoffset)
15+
LANGUAGE plpgsql
16+
VOLATILE
17+
SECURITY DEFINER
18+
AS $$
19+
BEGIN
20+
RETURN QUERY
21+
SELECT
22+
ag.id,
23+
ag.company_id,
24+
ag.name,
25+
ag.created_at,
26+
ag.updated_at
27+
FROM account_groups ag
28+
WHERE ag.id = p_group_id;
29+
END;
30+
$$;
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
CREATE TYPE datetimeoffset AS (local_time timestamp without time zone, offset_minutes smallint);
2+
3+
CREATE TABLE account_groups (
4+
id uuid NOT NULL,
5+
company_id uuid NOT NULL,
6+
name varchar NOT NULL,
7+
created_at datetimeoffset NOT NULL,
8+
updated_at datetimeoffset NOT NULL
9+
);
10+
11+
CREATE OR REPLACE FUNCTION get_account_group_by_id(
12+
p_group_id uuid
13+
)
14+
RETURNS TABLE(id uuid, company_id uuid, name varchar, created_at datetimeoffset, updated_at datetimeoffset)
15+
LANGUAGE plpgsql
16+
VOLATILE
17+
SECURITY DEFINER
18+
AS $$
19+
BEGIN
20+
RETURN QUERY
21+
SELECT
22+
ag.id,
23+
ag.company_id,
24+
ag.name,
25+
ag.created_at,
26+
ag.updated_at
27+
FROM account_groups ag
28+
WHERE ag.id = p_group_id;
29+
END;
30+
$$;
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"version": "1.0.0",
3+
"pgschema_version": "1.7.4",
4+
"created_at": "1970-01-01T00:00:00Z",
5+
"source_fingerprint": {
6+
"hash": "bc4fc478f2d7ae4cc204de3447d992dface8f485a9227504fed99b21817cb888"
7+
},
8+
"groups": null
9+
}

testdata/diff/create_function/issue_360_returns_table_custom_type/plan.sql

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
No changes detected.

0 commit comments

Comments
 (0)