|
1 | 1 | package postgres |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "fmt" |
4 | 5 | "reflect" |
| 6 | + "strings" |
5 | 7 | "testing" |
| 8 | + |
| 9 | + "github.com/jackc/pgx/v5/pgconn" |
6 | 10 | ) |
7 | 11 |
|
8 | 12 | func TestSplitDollarQuotedSegments(t *testing.T) { |
@@ -291,3 +295,70 @@ func TestStripSchemaQualifications_PreservesStringLiterals(t *testing.T) { |
291 | 295 | }) |
292 | 296 | } |
293 | 297 | } |
| 298 | + |
| 299 | +func TestEnhanceApplyError(t *testing.T) { |
| 300 | + sql := "CREATE TABLE foo (id int);\nCREATE TABLE bar (\n name text\n);\nSELECT 1;\nCREATE TABLE baz (id int);" |
| 301 | + |
| 302 | + t.Run("pgError with position", func(t *testing.T) { |
| 303 | + // Position points to "SELECT" on line 5 |
| 304 | + pos := int32(strings.Index(sql, "SELECT 1") + 1) // 1-based |
| 305 | + pgErr := &pgconn.PgError{ |
| 306 | + Message: "syntax error at or near \"SELECT\"", |
| 307 | + Code: "42601", |
| 308 | + Position: pos, |
| 309 | + } |
| 310 | + enhanced := enhanceApplyError(pgErr, sql) |
| 311 | + errMsg := enhanced.Error() |
| 312 | + |
| 313 | + if !strings.Contains(errMsg, "line 5") { |
| 314 | + t.Errorf("expected error to mention line 5, got: %s", errMsg) |
| 315 | + } |
| 316 | + if !strings.Contains(errMsg, "SELECT 1") { |
| 317 | + t.Errorf("expected error to contain the offending line, got: %s", errMsg) |
| 318 | + } |
| 319 | + // Should still contain original error |
| 320 | + if !strings.Contains(errMsg, "syntax error") { |
| 321 | + t.Errorf("expected error to contain original message, got: %s", errMsg) |
| 322 | + } |
| 323 | + }) |
| 324 | + |
| 325 | + t.Run("multi-byte UTF-8 position", func(t *testing.T) { |
| 326 | + // PostgreSQL Position counts characters, not bytes. |
| 327 | + // "café" is 4 characters but 5 bytes (é is 2 bytes in UTF-8). |
| 328 | + mbSQL := "-- café\nSELECT 1;" |
| 329 | + // "SELECT" starts at character position 9 (1-based): "-- café\n" = 8 chars |
| 330 | + pgErr := &pgconn.PgError{ |
| 331 | + Message: "syntax error", |
| 332 | + Code: "42601", |
| 333 | + Position: 9, |
| 334 | + } |
| 335 | + enhanced := enhanceApplyError(pgErr, mbSQL) |
| 336 | + errMsg := enhanced.Error() |
| 337 | + |
| 338 | + if !strings.Contains(errMsg, "line 2, column 1") { |
| 339 | + t.Errorf("expected line 2, column 1 for multi-byte SQL, got: %s", errMsg) |
| 340 | + } |
| 341 | + if !strings.Contains(errMsg, "SELECT 1") { |
| 342 | + t.Errorf("expected snippet to contain the error line, got: %s", errMsg) |
| 343 | + } |
| 344 | + }) |
| 345 | + |
| 346 | + t.Run("non-pg error passes through", func(t *testing.T) { |
| 347 | + origErr := fmt.Errorf("some other error") |
| 348 | + result := enhanceApplyError(origErr, sql) |
| 349 | + if result != origErr { |
| 350 | + t.Errorf("expected same error instance, got: %s", result.Error()) |
| 351 | + } |
| 352 | + }) |
| 353 | + |
| 354 | + t.Run("pgError without position passes through", func(t *testing.T) { |
| 355 | + pgErr := &pgconn.PgError{ |
| 356 | + Message: "some error", |
| 357 | + Code: "42601", |
| 358 | + } |
| 359 | + result := enhanceApplyError(pgErr, sql) |
| 360 | + if result != pgErr { |
| 361 | + t.Errorf("expected same error instance, got: %s", result.Error()) |
| 362 | + } |
| 363 | + }) |
| 364 | +} |
0 commit comments