From 5bc4ab23b52374dbc19defddab41aa52416b2d59 Mon Sep 17 00:00:00 2001 From: SAY-5 Date: Tue, 5 May 2026 03:14:32 -0700 Subject: [PATCH] fix(golang): cast string-enum fields in MySQL copyfrom encoder mysqltsv.AppendValue uses a strict type-switch and rejects string-derived enum types (e.g. ExperienceLocationsType) with "can't encode type X to TSV". Detect string enums in the go-sql-driver-mysql copyfrom template and emit AppendString(string(field)) instead. Fixes #4324 Signed-off-by: SAY-5 --- internal/codegen/golang/gen.go | 1 + .../go-sql-driver-mysql/copyfromCopy.tmpl | 17 ++++-- .../copyfrom_mysql_enum/mysql/go/copyfrom.go | 51 ++++++++++++++++ .../copyfrom_mysql_enum/mysql/go/db.go | 31 ++++++++++ .../copyfrom_mysql_enum/mysql/go/models.go | 59 +++++++++++++++++++ .../copyfrom_mysql_enum/mysql/go/query.sql.go | 16 +++++ .../copyfrom_mysql_enum/mysql/query.sql | 3 + .../copyfrom_mysql_enum/mysql/schema.sql | 4 ++ .../copyfrom_mysql_enum/mysql/sqlc.json | 14 +++++ 9 files changed, 191 insertions(+), 5 deletions(-) create mode 100644 internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/copyfrom.go create mode 100644 internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/db.go create mode 100644 internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/models.go create mode 100644 internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/copyfrom_mysql_enum/mysql/query.sql create mode 100644 internal/endtoend/testdata/copyfrom_mysql_enum/mysql/schema.sql create mode 100644 internal/endtoend/testdata/copyfrom_mysql_enum/mysql/sqlc.json diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 5b81c149c3..30297a3909 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -218,6 +218,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, "imports": i.Imports, "hasImports": i.HasImports, "hasPrefix": strings.HasPrefix, + "hasSuffix": strings.HasSuffix, // These methods are Go specific, they do not belong in the codegen package // (as that is language independent) diff --git a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl index e21475b148..897dbb87fc 100644 --- a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl +++ b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl @@ -7,13 +7,20 @@ func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) { e := mysqltsv.NewEncoder(w, {{ len .Arg.CopyFromMySQLFields }}, nil) for _, row := range {{.Arg.Name}} { {{- with $arg := .Arg }} +{{- $enums := $.Enums }} {{- range $arg.CopyFromMySQLFields}} -{{- if eq .Type "string"}} - e.AppendString({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) -{{- else if or (eq .Type "[]byte") (eq .Type "json.RawMessage")}} - e.AppendBytes({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) +{{- $expr := "row" }}{{ if $arg.Struct }}{{ $expr = print "row." .Name }}{{ end }} +{{- $fieldType := .Type }} +{{- $isStringEnum := false }} +{{- range $enums }}{{ if or (eq $fieldType .Name) (hasSuffix $fieldType (print "." .Name)) }}{{ $isStringEnum = true }}{{ end }}{{ end }} +{{- if eq $fieldType "string"}} + e.AppendString({{$expr}}) +{{- else if or (eq $fieldType "[]byte") (eq $fieldType "json.RawMessage")}} + e.AppendBytes({{$expr}}) +{{- else if $isStringEnum }} + e.AppendString(string({{$expr}})) {{- else}} - e.AppendValue({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) + e.AppendValue({{$expr}}) {{- end}} {{- end}} {{- end}} diff --git a/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/copyfrom.go b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/copyfrom.go new file mode 100644 index 0000000000..ed5bbaa16e --- /dev/null +++ b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/copyfrom.go @@ -0,0 +1,51 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: copyfrom.go + +package querytest + +import ( + "context" + "fmt" + "io" + "sync/atomic" + + "github.com/go-sql-driver/mysql" + "github.com/hexon/mysqltsv" +) + +var readerHandlerSequenceForUpsertExperienceLocations uint32 = 1 + +func convertRowsForUpsertExperienceLocations(w *io.PipeWriter, arg []UpsertExperienceLocationsParams) { + e := mysqltsv.NewEncoder(w, 2, nil) + for _, row := range arg { + e.AppendString(row.LocationID) + e.AppendString(string(row.Type)) + } + w.CloseWithError(e.Close()) +} + +// UpsertExperienceLocations uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. +// +// Errors and duplicate keys are treated as warnings and insertion will +// continue, even without an error for some cases. Use this in a transaction +// and use SHOW WARNINGS to check for any problems and roll back if you want to. +// +// Check the documentation for more information: +// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling +func (q *Queries) UpsertExperienceLocations(ctx context.Context, arg []UpsertExperienceLocationsParams) (int64, error) { + pr, pw := io.Pipe() + defer pr.Close() + rh := fmt.Sprintf("UpsertExperienceLocations_%d", atomic.AddUint32(&readerHandlerSequenceForUpsertExperienceLocations, 1)) + mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) + defer mysql.DeregisterReaderHandler(rh) + go convertRowsForUpsertExperienceLocations(pw, arg) + // The string interpolation is necessary because LOAD DATA INFILE requires + // the file name to be given as a literal string. + result, err := q.db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE `experience_locations` %s (location_id, type)", "Reader::"+rh, mysqltsv.Escaping)) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/db.go b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/models.go b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/models.go new file mode 100644 index 0000000000..296e688aa1 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/models.go @@ -0,0 +1,59 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "database/sql/driver" + "fmt" +) + +type ExperienceLocationsType string + +const ( + ExperienceLocationsTypeStartPoint ExperienceLocationsType = "start_point" + ExperienceLocationsTypePickupPoint ExperienceLocationsType = "pickup_point" + ExperienceLocationsTypeRedemptionPoint ExperienceLocationsType = "redemption_point" + ExperienceLocationsTypeEndPoint ExperienceLocationsType = "end_point" +) + +func (e *ExperienceLocationsType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ExperienceLocationsType(s) + case string: + *e = ExperienceLocationsType(s) + default: + return fmt.Errorf("unsupported scan type for ExperienceLocationsType: %T", src) + } + return nil +} + +type NullExperienceLocationsType struct { + ExperienceLocationsType ExperienceLocationsType + Valid bool // Valid is true if ExperienceLocationsType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullExperienceLocationsType) Scan(value interface{}) error { + if value == nil { + ns.ExperienceLocationsType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ExperienceLocationsType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullExperienceLocationsType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ExperienceLocationsType), nil +} + +type ExperienceLocation struct { + LocationID string + Type ExperienceLocationsType +} diff --git a/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/query.sql.go b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/query.sql.go new file mode 100644 index 0000000000..84a05c3fb6 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/go/query.sql.go @@ -0,0 +1,16 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +const upsertExperienceLocations = `-- name: UpsertExperienceLocations :copyfrom +REPLACE INTO experience_locations (location_id, type) +VALUES (?, ?) +` + +type UpsertExperienceLocationsParams struct { + LocationID string + Type ExperienceLocationsType +} diff --git a/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/query.sql b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/query.sql new file mode 100644 index 0000000000..0a1e1a5722 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/query.sql @@ -0,0 +1,3 @@ +-- name: UpsertExperienceLocations :copyfrom +REPLACE INTO experience_locations (location_id, type) +VALUES (?, ?); diff --git a/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/schema.sql b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/schema.sql new file mode 100644 index 0000000000..62df1e5dff --- /dev/null +++ b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE experience_locations ( + location_id varchar(512) NOT NULL, + type ENUM('start_point', 'pickup_point', 'redemption_point', 'end_point') NOT NULL +); diff --git a/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/sqlc.json b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/sqlc.json new file mode 100644 index 0000000000..7dabfeef72 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom_mysql_enum/mysql/sqlc.json @@ -0,0 +1,14 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "sql_package": "database/sql", + "sql_driver": "github.com/go-sql-driver/mysql", + "engine": "mysql", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +}