@@ -3,6 +3,7 @@ package regresql
33import (
44 "database/sql"
55 "fmt"
6+ "strings"
67)
78
89type (
1314
1415 // TableInfo contains metadata about table
1516 TableInfo struct {
17+ Schema string
1618 Name string
1719 Columns map [string ]* ColumnInfo
1820 PrimaryKey []string
@@ -43,34 +45,37 @@ type (
4345
4446// IntrospectSchema queries the database to build schema metadata
4547func IntrospectSchema (db * sql.DB ) (* DatabaseSchema , error ) {
46- schema := & DatabaseSchema {
48+ dbSchema := & DatabaseSchema {
4749 tables : make (map [string ]* TableInfo ),
4850 }
4951
50- // Get all tables
52+ // Get all tables (schema-qualified names like "auth.users")
5153 tables , err := getTables (db )
5254 if err != nil {
5355 return nil , fmt .Errorf ("failed to get tables: %w" , err )
5456 }
5557
5658 // For each table, get columns, primary keys, and foreign keys
57- for _ , tableName := range tables {
59+ for _ , qualifiedName := range tables {
60+ schemaName , tableName := parseTableName (qualifiedName )
61+
5862 tableInfo := & TableInfo {
63+ Schema : schemaName ,
5964 Name : tableName ,
6065 Columns : make (map [string ]* ColumnInfo ),
6166 }
6267
6368 // Get columns
64- columns , err := getColumns (db , tableName )
69+ columns , err := getColumns (db , schemaName , tableName )
6570 if err != nil {
66- return nil , fmt .Errorf ("failed to get columns for table '%s': %w" , tableName , err )
71+ return nil , fmt .Errorf ("failed to get columns for table '%s': %w" , qualifiedName , err )
6772 }
6873 tableInfo .Columns = columns
6974
7075 // Get primary keys
71- primaryKeys , err := getPrimaryKeys (db , tableName )
76+ primaryKeys , err := getPrimaryKeys (db , schemaName , tableName )
7277 if err != nil {
73- return nil , fmt .Errorf ("failed to get primary keys for table '%s': %w" , tableName , err )
78+ return nil , fmt .Errorf ("failed to get primary keys for table '%s': %w" , qualifiedName , err )
7479 }
7580 tableInfo .PrimaryKey = primaryKeys
7681
@@ -82,9 +87,9 @@ func IntrospectSchema(db *sql.DB) (*DatabaseSchema, error) {
8287 }
8388
8489 // Get foreign keys
85- foreignKeys , err := getForeignKeys (db , tableName )
90+ foreignKeys , err := getForeignKeys (db , schemaName , tableName )
8691 if err != nil {
87- return nil , fmt .Errorf ("failed to get foreign keys for table '%s': %w" , tableName , err )
92+ return nil , fmt .Errorf ("failed to get foreign keys for table '%s': %w" , qualifiedName , err )
8893 }
8994 tableInfo .ForeignKeys = foreignKeys
9095
@@ -97,30 +102,30 @@ func IntrospectSchema(db *sql.DB) (*DatabaseSchema, error) {
97102 }
98103
99104 // Get unique constraints
100- uniqueCols , err := getUniqueColumns (db , tableName )
105+ uniqueCols , err := getUniqueColumns (db , schemaName , tableName )
101106 if err != nil {
102- return nil , fmt .Errorf ("failed to get unique constraints for table '%s': %w" , tableName , err )
107+ return nil , fmt .Errorf ("failed to get unique constraints for table '%s': %w" , qualifiedName , err )
103108 }
104109 for colName := range uniqueCols {
105110 if col , exists := tableInfo .Columns [colName ]; exists {
106111 col .IsUnique = true
107112 }
108113 }
109114
110- schema .tables [tableName ] = tableInfo
115+ dbSchema .tables [qualifiedName ] = tableInfo
111116 }
112117
113- return schema , nil
118+ return dbSchema , nil
114119}
115120
116- // getTables retrieves all table names from the database
121+ // getTables retrieves all table names from the database (all user schemas)
117122func getTables (db * sql.DB ) ([]string , error ) {
118123 query := `
119- SELECT table_name
124+ SELECT table_schema, table_name
120125 FROM information_schema.tables
121- WHERE table_schema = 'public'
126+ WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
122127 AND table_type = 'BASE TABLE'
123- ORDER BY table_name
128+ ORDER BY table_schema, table_name
124129 `
125130
126131 rows , err := db .Query (query )
@@ -131,18 +136,27 @@ func getTables(db *sql.DB) ([]string, error) {
131136
132137 var tables []string
133138 for rows .Next () {
134- var tableName string
135- if err := rows .Scan (& tableName ); err != nil {
139+ var schemaName , tableName string
140+ if err := rows .Scan (& schemaName , & tableName ); err != nil {
136141 return nil , err
137142 }
138- tables = append (tables , tableName )
143+ tables = append (tables , schemaName + "." + tableName )
139144 }
140145
141146 return tables , rows .Err ()
142147}
143148
149+ // parseTableName splits a schema-qualified table name into schema and table parts
150+ func parseTableName (name string ) (schema , table string ) {
151+ parts := strings .SplitN (name , "." , 2 )
152+ if len (parts ) == 2 {
153+ return parts [0 ], parts [1 ]
154+ }
155+ return "public" , name
156+ }
157+
144158// getColumns retrieves column metadata for a table
145- func getColumns (db * sql.DB , tableName string ) (map [string ]* ColumnInfo , error ) {
159+ func getColumns (db * sql.DB , schemaName , tableName string ) (map [string ]* ColumnInfo , error ) {
146160 query := `
147161 SELECT
148162 column_name,
@@ -151,12 +165,12 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
151165 column_default,
152166 character_maximum_length
153167 FROM information_schema.columns
154- WHERE table_schema = 'public'
155- AND table_name = $1
168+ WHERE table_schema = $1
169+ AND table_name = $2
156170 ORDER BY ordinal_position
157171 `
158172
159- rows , err := db .Query (query , tableName )
173+ rows , err := db .Query (query , schemaName , tableName )
160174 if err != nil {
161175 return nil , err
162176 }
@@ -165,11 +179,11 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
165179 columns := make (map [string ]* ColumnInfo )
166180 for rows .Next () {
167181 var (
168- columnName string
169- dataType string
170- isNullable string
171- columnDefault sql. NullString
172- maxLength sql. NullInt64
182+ columnName string
183+ dataType string
184+ isNullable string
185+ columnDefault * string
186+ maxLength * int64
173187 )
174188
175189 if err := rows .Scan (& columnName , & dataType , & isNullable , & columnDefault , & maxLength ); err != nil {
@@ -180,14 +194,11 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
180194 Name : columnName ,
181195 Type : dataType ,
182196 IsNullable : isNullable == "YES" ,
197+ Default : columnDefault ,
183198 }
184199
185- if columnDefault .Valid {
186- col .Default = & columnDefault .String
187- }
188-
189- if maxLength .Valid {
190- length := int (maxLength .Int64 )
200+ if maxLength != nil {
201+ length := int (* maxLength )
191202 col .MaxLength = & length
192203 }
193204
@@ -198,7 +209,9 @@ func getColumns(db *sql.DB, tableName string) (map[string]*ColumnInfo, error) {
198209}
199210
200211// getPrimaryKeys retrieves primary key column names for a table
201- func getPrimaryKeys (db * sql.DB , tableName string ) ([]string , error ) {
212+ func getPrimaryKeys (db * sql.DB , schemaName , tableName string ) ([]string , error ) {
213+ // Use schema-qualified name for regclass cast
214+ qualifiedName := schemaName + "." + tableName
202215 query := `
203216 SELECT a.attname
204217 FROM pg_index i
@@ -208,7 +221,7 @@ func getPrimaryKeys(db *sql.DB, tableName string) ([]string, error) {
208221 ORDER BY array_position(i.indkey, a.attnum)
209222 `
210223
211- rows , err := db .Query (query , tableName )
224+ rows , err := db .Query (query , qualifiedName )
212225 if err != nil {
213226 return nil , err
214227 }
@@ -227,11 +240,12 @@ func getPrimaryKeys(db *sql.DB, tableName string) ([]string, error) {
227240}
228241
229242// getForeignKeys retrieves foreign key constraints for a table
230- func getForeignKeys (db * sql.DB , tableName string ) ([]* ForeignKeyInfo , error ) {
243+ func getForeignKeys (db * sql.DB , schemaName , tableName string ) ([]* ForeignKeyInfo , error ) {
231244 query := `
232245 SELECT
233246 tc.constraint_name,
234247 kcu.column_name,
248+ ccu.table_schema AS referenced_schema,
235249 ccu.table_name AS referenced_table,
236250 ccu.column_name AS referenced_column
237251 FROM information_schema.table_constraints AS tc
@@ -240,13 +254,12 @@ func getForeignKeys(db *sql.DB, tableName string) ([]*ForeignKeyInfo, error) {
240254 AND tc.table_schema = kcu.table_schema
241255 JOIN information_schema.constraint_column_usage AS ccu
242256 ON ccu.constraint_name = tc.constraint_name
243- AND ccu.table_schema = tc.table_schema
244257 WHERE tc.constraint_type = 'FOREIGN KEY'
245258 AND tc.table_name = $1
246- AND tc.table_schema = 'public'
259+ AND tc.table_schema = $2
247260 `
248261
249- rows , err := db .Query (query , tableName )
262+ rows , err := db .Query (query , tableName , schemaName )
250263 if err != nil {
251264 return nil , err
252265 }
@@ -255,16 +268,21 @@ func getForeignKeys(db *sql.DB, tableName string) ([]*ForeignKeyInfo, error) {
255268 var foreignKeys []* ForeignKeyInfo
256269 for rows .Next () {
257270 var fk ForeignKeyInfo
258- if err := rows .Scan (& fk .ConstraintName , & fk .ColumnName , & fk .ReferencedTable , & fk .ReferencedColumn ); err != nil {
271+ var refSchema string
272+ if err := rows .Scan (& fk .ConstraintName , & fk .ColumnName , & refSchema , & fk .ReferencedTable , & fk .ReferencedColumn ); err != nil {
259273 return nil , err
260274 }
275+ // Store schema-qualified referenced table name
276+ fk .ReferencedTable = refSchema + "." + fk .ReferencedTable
261277 foreignKeys = append (foreignKeys , & fk )
262278 }
263279
264280 return foreignKeys , rows .Err ()
265281}
266282
267- func getUniqueColumns (db * sql.DB , tableName string ) (map [string ]bool , error ) {
283+ func getUniqueColumns (db * sql.DB , schemaName , tableName string ) (map [string ]bool , error ) {
284+ // Use schema-qualified name for regclass cast
285+ qualifiedName := schemaName + "." + tableName
268286 query := `
269287 SELECT a.attname
270288 FROM pg_index i
@@ -275,7 +293,7 @@ func getUniqueColumns(db *sql.DB, tableName string) (map[string]bool, error) {
275293 AND array_length(i.indkey, 1) = 1
276294 `
277295
278- rows , err := db .Query (query , tableName )
296+ rows , err := db .Query (query , qualifiedName )
279297 if err != nil {
280298 return nil , err
281299 }
@@ -292,13 +310,19 @@ func getUniqueColumns(db *sql.DB, tableName string) (map[string]bool, error) {
292310 return uniqueCols , rows .Err ()
293311}
294312
295- // GetTable retrieves table metadata
313+ // GetTable retrieves table metadata by name (schema-qualified or unqualified)
296314func (ds * DatabaseSchema ) GetTable (name string ) (* TableInfo , error ) {
297- table , exists := ds . tables [ name ]
298- if ! exists {
299- return nil , fmt . Errorf ( " table not found: %s" , name )
315+ // Try exact match first (for schema-qualified names)
316+ if table , exists := ds . tables [ name ]; exists {
317+ return table , nil
300318 }
301- return table , nil
319+ // If no dot in name, try public schema
320+ if ! strings .Contains (name , "." ) {
321+ if table , exists := ds .tables ["public." + name ]; exists {
322+ return table , nil
323+ }
324+ }
325+ return nil , fmt .Errorf ("table not found: %s" , name )
302326}
303327
304328// GetTables returns all table names
@@ -317,11 +341,14 @@ func (ds *DatabaseSchema) GetForeignKeyDependencies(tableName string) ([]string,
317341 return nil , err
318342 }
319343
344+ // Build the qualified name for self-reference comparison
345+ qualifiedName := table .Schema + "." + table .Name
346+
320347 // Collect unique referenced tables
321348 deps := make (map [string ]bool )
322349 for _ , fk := range table .ForeignKeys {
323350 // Don't include self-references as dependencies
324- if fk .ReferencedTable != tableName {
351+ if fk .ReferencedTable != qualifiedName {
325352 deps [fk .ReferencedTable ] = true
326353 }
327354 }
@@ -337,6 +364,13 @@ func (ds *DatabaseSchema) GetForeignKeyDependencies(tableName string) ([]string,
337364
338365// HasTable checks if a table exists in the schema
339366func (ds * DatabaseSchema ) HasTable (name string ) bool {
340- _ , exists := ds .tables [name ]
341- return exists
367+ if _ , exists := ds .tables [name ]; exists {
368+ return true
369+ }
370+ // If no dot in name, try public schema
371+ if ! strings .Contains (name , "." ) {
372+ _ , exists := ds .tables ["public." + name ]
373+ return exists
374+ }
375+ return false
342376}
0 commit comments