Skip to content

Commit 96e782d

Browse files
committed
fix: correct multi-schema introspection
1 parent 158d201 commit 96e782d

1 file changed

Lines changed: 86 additions & 52 deletions

File tree

regresql/schema.go

Lines changed: 86 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package regresql
33
import (
44
"database/sql"
55
"fmt"
6+
"strings"
67
)
78

89
type (
@@ -13,6 +14,7 @@ type (
1314

1415
// TableInfo contains metadata about table
1516
TableInfo struct {
17+
Schema string
1618
Name string
1719
Columns map[string]*ColumnInfo
1820
PrimaryKey []string
@@ -43,34 +45,37 @@ type (
4345

4446
// IntrospectSchema queries the database to build schema metadata
4547
func IntrospectSchema(db *sql.DB) (*DatabaseSchema, error) {
46-
schema := &DatabaseSchema{
48+
dbSchema := &DatabaseSchema{
4749
tables: make(map[string]*TableInfo),
4850
}
4951

50-
// Get all tables
52+
// Get all tables (schema-qualified names like "auth.users")
5153
tables, err := getTables(db)
5254
if err != nil {
5355
return nil, fmt.Errorf("failed to get tables: %w", err)
5456
}
5557

5658
// For each table, get columns, primary keys, and foreign keys
57-
for _, tableName := range tables {
59+
for _, qualifiedName := range tables {
60+
schemaName, tableName := parseTableName(qualifiedName)
61+
5862
tableInfo := &TableInfo{
63+
Schema: schemaName,
5964
Name: tableName,
6065
Columns: make(map[string]*ColumnInfo),
6166
}
6267

6368
// Get columns
64-
columns, err := getColumns(db, tableName)
69+
columns, err := getColumns(db, schemaName, tableName)
6570
if err != nil {
66-
return nil, fmt.Errorf("failed to get columns for table '%s': %w", tableName, err)
71+
return nil, fmt.Errorf("failed to get columns for table '%s': %w", qualifiedName, err)
6772
}
6873
tableInfo.Columns = columns
6974

7075
// Get primary keys
71-
primaryKeys, err := getPrimaryKeys(db, tableName)
76+
primaryKeys, err := getPrimaryKeys(db, schemaName, tableName)
7277
if err != nil {
73-
return nil, fmt.Errorf("failed to get primary keys for table '%s': %w", tableName, err)
78+
return nil, fmt.Errorf("failed to get primary keys for table '%s': %w", qualifiedName, err)
7479
}
7580
tableInfo.PrimaryKey = primaryKeys
7681

@@ -82,9 +87,9 @@ func IntrospectSchema(db *sql.DB) (*DatabaseSchema, error) {
8287
}
8388

8489
// Get foreign keys
85-
foreignKeys, err := getForeignKeys(db, tableName)
90+
foreignKeys, err := getForeignKeys(db, schemaName, tableName)
8691
if err != nil {
87-
return nil, fmt.Errorf("failed to get foreign keys for table '%s': %w", tableName, err)
92+
return nil, fmt.Errorf("failed to get foreign keys for table '%s': %w", qualifiedName, err)
8893
}
8994
tableInfo.ForeignKeys = foreignKeys
9095

@@ -97,30 +102,30 @@ func IntrospectSchema(db *sql.DB) (*DatabaseSchema, error) {
97102
}
98103

99104
// Get unique constraints
100-
uniqueCols, err := getUniqueColumns(db, tableName)
105+
uniqueCols, err := getUniqueColumns(db, schemaName, tableName)
101106
if err != nil {
102-
return nil, fmt.Errorf("failed to get unique constraints for table '%s': %w", tableName, err)
107+
return nil, fmt.Errorf("failed to get unique constraints for table '%s': %w", qualifiedName, err)
103108
}
104109
for colName := range uniqueCols {
105110
if col, exists := tableInfo.Columns[colName]; exists {
106111
col.IsUnique = true
107112
}
108113
}
109114

110-
schema.tables[tableName] = tableInfo
115+
dbSchema.tables[qualifiedName] = tableInfo
111116
}
112117

113-
return schema, nil
118+
return dbSchema, nil
114119
}
115120

116-
// getTables retrieves all table names from the database
121+
// getTables retrieves all table names from the database (all user schemas)
117122
func getTables(db *sql.DB) ([]string, error) {
118123
query := `
119-
SELECT table_name
124+
SELECT table_schema, table_name
120125
FROM information_schema.tables
121-
WHERE table_schema = 'public'
126+
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
122127
AND table_type = 'BASE TABLE'
123-
ORDER BY table_name
128+
ORDER BY table_schema, table_name
124129
`
125130

126131
rows, err := db.Query(query)
@@ -131,18 +136,27 @@ func getTables(db *sql.DB) ([]string, error) {
131136

132137
var tables []string
133138
for rows.Next() {
134-
var tableName string
135-
if err := rows.Scan(&tableName); err != nil {
139+
var schemaName, tableName string
140+
if err := rows.Scan(&schemaName, &tableName); err != nil {
136141
return nil, err
137142
}
138-
tables = append(tables, tableName)
143+
tables = append(tables, schemaName+"."+tableName)
139144
}
140145

141146
return tables, rows.Err()
142147
}
143148

149+
// parseTableName splits a schema-qualified table name into schema and table parts
150+
func parseTableName(name string) (schema, table string) {
151+
parts := strings.SplitN(name, ".", 2)
152+
if len(parts) == 2 {
153+
return parts[0], parts[1]
154+
}
155+
return "public", name
156+
}
157+
144158
// getColumns retrieves column metadata for a table
145-
func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
159+
func getColumns(db *sql.DB, schemaName, tableName string) (map[string]*ColumnInfo, error) {
146160
query := `
147161
SELECT
148162
column_name,
@@ -151,12 +165,12 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
151165
column_default,
152166
character_maximum_length
153167
FROM information_schema.columns
154-
WHERE table_schema = 'public'
155-
AND table_name = $1
168+
WHERE table_schema = $1
169+
AND table_name = $2
156170
ORDER BY ordinal_position
157171
`
158172

159-
rows, err := db.Query(query, tableName)
173+
rows, err := db.Query(query, schemaName, tableName)
160174
if err != nil {
161175
return nil, err
162176
}
@@ -165,11 +179,11 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
165179
columns := make(map[string]*ColumnInfo)
166180
for rows.Next() {
167181
var (
168-
columnName string
169-
dataType string
170-
isNullable string
171-
columnDefault sql.NullString
172-
maxLength sql.NullInt64
182+
columnName string
183+
dataType string
184+
isNullable string
185+
columnDefault *string
186+
maxLength *int64
173187
)
174188

175189
if err := rows.Scan(&columnName, &dataType, &isNullable, &columnDefault, &maxLength); err != nil {
@@ -180,14 +194,11 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
180194
Name: columnName,
181195
Type: dataType,
182196
IsNullable: isNullable == "YES",
197+
Default: columnDefault,
183198
}
184199

185-
if columnDefault.Valid {
186-
col.Default = &columnDefault.String
187-
}
188-
189-
if maxLength.Valid {
190-
length := int(maxLength.Int64)
200+
if maxLength != nil {
201+
length := int(*maxLength)
191202
col.MaxLength = &length
192203
}
193204

@@ -198,7 +209,9 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
198209
}
199210

200211
// getPrimaryKeys retrieves primary key column names for a table
201-
func getPrimaryKeys(db *sql.DB, tableName string) ([]string, error) {
212+
func getPrimaryKeys(db *sql.DB, schemaName, tableName string) ([]string, error) {
213+
// Use schema-qualified name for regclass cast
214+
qualifiedName := schemaName + "." + tableName
202215
query := `
203216
SELECT a.attname
204217
FROM pg_index i
@@ -208,7 +221,7 @@ func getPrimaryKeys(db *sql.DB, tableName string) ([]string, error) {
208221
ORDER BY array_position(i.indkey, a.attnum)
209222
`
210223

211-
rows, err := db.Query(query, tableName)
224+
rows, err := db.Query(query, qualifiedName)
212225
if err != nil {
213226
return nil, err
214227
}
@@ -227,11 +240,12 @@ func getPrimaryKeys(db *sql.DB, tableName string) ([]string, error) {
227240
}
228241

229242
// getForeignKeys retrieves foreign key constraints for a table
230-
func getForeignKeys(db *sql.DB, tableName string) ([]*ForeignKeyInfo, error) {
243+
func getForeignKeys(db *sql.DB, schemaName, tableName string) ([]*ForeignKeyInfo, error) {
231244
query := `
232245
SELECT
233246
tc.constraint_name,
234247
kcu.column_name,
248+
ccu.table_schema AS referenced_schema,
235249
ccu.table_name AS referenced_table,
236250
ccu.column_name AS referenced_column
237251
FROM information_schema.table_constraints AS tc
@@ -240,13 +254,12 @@ func getForeignKeys(db *sql.DB, tableName string) ([]*ForeignKeyInfo, error) {
240254
AND tc.table_schema = kcu.table_schema
241255
JOIN information_schema.constraint_column_usage AS ccu
242256
ON ccu.constraint_name = tc.constraint_name
243-
AND ccu.table_schema = tc.table_schema
244257
WHERE tc.constraint_type = 'FOREIGN KEY'
245258
AND tc.table_name = $1
246-
AND tc.table_schema = 'public'
259+
AND tc.table_schema = $2
247260
`
248261

249-
rows, err := db.Query(query, tableName)
262+
rows, err := db.Query(query, tableName, schemaName)
250263
if err != nil {
251264
return nil, err
252265
}
@@ -255,16 +268,21 @@ func getForeignKeys(db *sql.DB, tableName string) ([]*ForeignKeyInfo, error) {
255268
var foreignKeys []*ForeignKeyInfo
256269
for rows.Next() {
257270
var fk ForeignKeyInfo
258-
if err := rows.Scan(&fk.ConstraintName, &fk.ColumnName, &fk.ReferencedTable, &fk.ReferencedColumn); err != nil {
271+
var refSchema string
272+
if err := rows.Scan(&fk.ConstraintName, &fk.ColumnName, &refSchema, &fk.ReferencedTable, &fk.ReferencedColumn); err != nil {
259273
return nil, err
260274
}
275+
// Store schema-qualified referenced table name
276+
fk.ReferencedTable = refSchema + "." + fk.ReferencedTable
261277
foreignKeys = append(foreignKeys, &fk)
262278
}
263279

264280
return foreignKeys, rows.Err()
265281
}
266282

267-
func getUniqueColumns(db *sql.DB, tableName string) (map[string]bool, error) {
283+
func getUniqueColumns(db *sql.DB, schemaName, tableName string) (map[string]bool, error) {
284+
// Use schema-qualified name for regclass cast
285+
qualifiedName := schemaName + "." + tableName
268286
query := `
269287
SELECT a.attname
270288
FROM pg_index i
@@ -275,7 +293,7 @@ func getUniqueColumns(db *sql.DB, tableName string) (map[string]bool, error) {
275293
AND array_length(i.indkey, 1) = 1
276294
`
277295

278-
rows, err := db.Query(query, tableName)
296+
rows, err := db.Query(query, qualifiedName)
279297
if err != nil {
280298
return nil, err
281299
}
@@ -292,13 +310,19 @@ func getUniqueColumns(db *sql.DB, tableName string) (map[string]bool, error) {
292310
return uniqueCols, rows.Err()
293311
}
294312

295-
// GetTable retrieves table metadata
313+
// GetTable retrieves table metadata by name (schema-qualified or unqualified)
296314
func (ds *DatabaseSchema) GetTable(name string) (*TableInfo, error) {
297-
table, exists := ds.tables[name]
298-
if !exists {
299-
return nil, fmt.Errorf("table not found: %s", name)
315+
// Try exact match first (for schema-qualified names)
316+
if table, exists := ds.tables[name]; exists {
317+
return table, nil
300318
}
301-
return table, nil
319+
// If no dot in name, try public schema
320+
if !strings.Contains(name, ".") {
321+
if table, exists := ds.tables["public."+name]; exists {
322+
return table, nil
323+
}
324+
}
325+
return nil, fmt.Errorf("table not found: %s", name)
302326
}
303327

304328
// GetTables returns all table names
@@ -317,11 +341,14 @@ func (ds *DatabaseSchema) GetForeignKeyDependencies(tableName string) ([]string,
317341
return nil, err
318342
}
319343

344+
// Build the qualified name for self-reference comparison
345+
qualifiedName := table.Schema + "." + table.Name
346+
320347
// Collect unique referenced tables
321348
deps := make(map[string]bool)
322349
for _, fk := range table.ForeignKeys {
323350
// Don't include self-references as dependencies
324-
if fk.ReferencedTable != tableName {
351+
if fk.ReferencedTable != qualifiedName {
325352
deps[fk.ReferencedTable] = true
326353
}
327354
}
@@ -337,6 +364,13 @@ func (ds *DatabaseSchema) GetForeignKeyDependencies(tableName string) ([]string,
337364

338365
// HasTable checks if a table exists in the schema
339366
func (ds *DatabaseSchema) HasTable(name string) bool {
340-
_, exists := ds.tables[name]
341-
return exists
367+
if _, exists := ds.tables[name]; exists {
368+
return true
369+
}
370+
// If no dot in name, try public schema
371+
if !strings.Contains(name, ".") {
372+
_, exists := ds.tables["public."+name]
373+
return exists
374+
}
375+
return false
342376
}

0 commit comments

Comments
 (0)