From eda87b86f49b82894757fc4670d8077e98cb0d49 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 28 Apr 2026 05:33:12 +0000 Subject: [PATCH 1/4] Introduce internal/api package with Generate entry point The new package mirrors esbuild's Build API: a single api.Generate(ctx, api.GenerateOptions{}) call returns a GenerateResult containing the generated files and any errors. Most of cmd/generate.go's logic moves here as unexported helpers; the only exported names are Generate, GenerateOptions, and GenerateResult. cmd.Generate is now a thin wrapper that translates the CLI's Options struct into api.GenerateOptions. The endtoend tests call api.Generate directly for TestExamples, TestReplay (generate command), and the benchmarks. https://claude.ai/code/session_01RCzB2JR5Y5ScFDUmwcxGVZ --- internal/api/api.go | 11 ++ internal/api/codegen.go | 108 +++++++++++++ internal/api/config.go | 93 ++++++++++++ internal/api/generate.go | 174 +++++++++++++++++++++ internal/api/parse.go | 63 ++++++++ internal/api/process.go | 109 ++++++++++++++ internal/api/shim.go | 233 +++++++++++++++++++++++++++++ internal/cmd/generate.go | 206 ++----------------------- internal/endtoend/endtoend_test.go | 46 +++--- 9 files changed, 833 insertions(+), 210 deletions(-) create mode 100644 internal/api/api.go create mode 100644 internal/api/codegen.go create mode 100644 internal/api/config.go create mode 100644 internal/api/generate.go create mode 100644 internal/api/parse.go create mode 100644 internal/api/process.go create mode 100644 internal/api/shim.go diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 0000000000..c1011de572 --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,11 @@ +// Package api is intended to be the future public API for sqlc. +// +// The shape of this package is inspired by esbuild's Build API +// (https://pkg.go.dev/github.com/evanw/esbuild/pkg/api#hdr-Build_API): a small +// surface area of options structs and result structs that lets callers drive +// sqlc programmatically without going through the CLI. +// +// Today the package lives under internal/ while the API stabilises. Once the +// surface settles it is expected to graduate to pkg/api so it can be imported +// by external Go programs. +package api diff --git a/internal/api/codegen.go b/internal/api/codegen.go new file mode 100644 index 0000000000..d733c3b2d4 --- /dev/null +++ b/internal/api/codegen.go @@ -0,0 +1,108 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "runtime/trace" + + "google.golang.org/grpc" + + "github.com/sqlc-dev/sqlc/internal/codegen/golang" + genjson "github.com/sqlc-dev/sqlc/internal/codegen/json" + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/config/convert" + "github.com/sqlc-dev/sqlc/internal/ext" + "github.com/sqlc-dev/sqlc/internal/ext/process" + "github.com/sqlc-dev/sqlc/internal/ext/wasm" + "github.com/sqlc-dev/sqlc/internal/plugin" +) + +func findPlugin(conf config.Config, name string) (*config.Plugin, error) { + for _, plug := range conf.Plugins { + if plug.Name == name { + return &plug, nil + } + } + return nil, fmt.Errorf("plugin not found") +} + +func codegen(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) { + defer trace.StartRegion(ctx, "codegen").End() + req := codeGenRequest(result, combo) + var handler grpc.ClientConnInterface + var out string + switch { + case sql.Plugin != nil: + out = sql.Plugin.Out + plug, err := findPlugin(combo.Global, sql.Plugin.Plugin) + if err != nil { + return "", nil, fmt.Errorf("plugin not found: %s", err) + } + + switch { + case plug.Process != nil: + handler = &process.Runner{ + Cmd: plug.Process.Cmd, + Env: plug.Env, + Format: plug.Process.Format, + } + case plug.WASM != nil: + handler = &wasm.Runner{ + URL: plug.WASM.URL, + SHA256: plug.WASM.SHA256, + Env: plug.Env, + } + default: + return "", nil, fmt.Errorf("unsupported plugin type") + } + + opts, err := convert.YAMLtoJSON(sql.Plugin.Options) + if err != nil { + return "", nil, fmt.Errorf("invalid plugin options: %w", err) + } + req.PluginOptions = opts + + global, found := combo.Global.Options[plug.Name] + if found { + opts, err := convert.YAMLtoJSON(global) + if err != nil { + return "", nil, fmt.Errorf("invalid global options: %w", err) + } + req.GlobalOptions = opts + } + + case sql.Gen.Go != nil: + out = combo.Go.Out + handler = ext.HandleFunc(golang.Generate) + opts, err := json.Marshal(sql.Gen.Go) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.PluginOptions = opts + + if combo.Global.Overrides.Go != nil { + opts, err := json.Marshal(combo.Global.Overrides.Go) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.GlobalOptions = opts + } + + case sql.Gen.JSON != nil: + out = combo.JSON.Out + handler = ext.HandleFunc(genjson.Generate) + opts, err := json.Marshal(sql.Gen.JSON) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.PluginOptions = opts + + default: + return "", nil, fmt.Errorf("missing language backend") + } + client := plugin.NewCodegenServiceClient(handler) + resp, err := client.Generate(ctx, req) + return out, resp, err +} diff --git a/internal/api/config.go b/internal/api/config.go new file mode 100644 index 0000000000..00eef38d19 --- /dev/null +++ b/internal/api/config.go @@ -0,0 +1,93 @@ +package api + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/sqlc-dev/sqlc/internal/config" +) + +const errMessageNoVersion = `The configuration file must have a version number. +Set the version to 1 or 2 at the top of sqlc.json: + +{ + "version": "1" + ... +} +` + +const errMessageUnknownVersion = `The configuration file has an invalid version number. +The supported version can only be "1" or "2". +` + +const errMessageNoPackages = `No packages are configured` + +func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, error) { + configPath := "" + if filename != "" { + configPath = filepath.Join(dir, filename) + } else { + var yamlMissing, jsonMissing, ymlMissing bool + yamlPath := filepath.Join(dir, "sqlc.yaml") + ymlPath := filepath.Join(dir, "sqlc.yml") + jsonPath := filepath.Join(dir, "sqlc.json") + + if _, err := os.Stat(yamlPath); os.IsNotExist(err) { + yamlMissing = true + } + if _, err := os.Stat(jsonPath); os.IsNotExist(err) { + jsonMissing = true + } + + if _, err := os.Stat(ymlPath); os.IsNotExist(err) { + ymlMissing = true + } + + if yamlMissing && ymlMissing && jsonMissing { + fmt.Fprintln(stderr, "error parsing configuration files. sqlc.(yaml|yml) or sqlc.json: file does not exist") + return "", nil, errors.New("config file missing") + } + + if (!yamlMissing || !ymlMissing) && !jsonMissing { + fmt.Fprintln(stderr, "error: both sqlc.json and sqlc.(yaml|yml) files present") + return "", nil, errors.New("sqlc.json and sqlc.(yaml|yml) present") + } + + if jsonMissing { + if yamlMissing { + configPath = ymlPath + } else { + configPath = yamlPath + } + } else { + configPath = jsonPath + } + } + + base := filepath.Base(configPath) + file, err := os.Open(configPath) + if err != nil { + fmt.Fprintf(stderr, "error parsing %s: file does not exist\n", base) + return "", nil, err + } + defer file.Close() + + conf, err := config.ParseConfig(file) + if err != nil { + switch err { + case config.ErrMissingVersion: + fmt.Fprint(stderr, errMessageNoVersion) + case config.ErrUnknownVersion: + fmt.Fprint(stderr, errMessageUnknownVersion) + case config.ErrNoPackages: + fmt.Fprint(stderr, errMessageNoPackages) + } + fmt.Fprintf(stderr, "error parsing %s: %s\n", base, err) + return "", nil, err + } + + return configPath, &conf, nil +} diff --git a/internal/api/generate.go b/internal/api/generate.go new file mode 100644 index 0000000000..5d10f015d1 --- /dev/null +++ b/internal/api/generate.go @@ -0,0 +1,174 @@ +package api + +import ( + "context" + "errors" + "fmt" + "io" + "path/filepath" + "strings" + "sync" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" +) + +// errPluginProcessDisabled is returned when the configuration uses a process +// plugin but the caller has disabled them via GenerateOptions.DisableProcessPlugins. +var errPluginProcessDisabled = errors.New("plugin: process-based plugins disabled via SQLCDEBUG=processplugins=0") + +// GenerateOptions controls a single Generate invocation. +type GenerateOptions struct { + // Dir is the working directory used to resolve the config file and any + // relative schema/query paths within it. + Dir string + + // File is the configuration filename to use, relative to Dir. When empty, + // Generate looks for sqlc.yaml, sqlc.yml, or sqlc.json in Dir. + File string + + // Stderr receives diagnostic output. If nil, output is discarded. + Stderr io.Writer + + // DisableProcessPlugins, when true, causes Generate to fail if the + // configuration uses a process-based plugin. The sqlc CLI sets this from + // SQLCDEBUG=processplugins=0. + DisableProcessPlugins bool + + // MutateConfig is called after the configuration is parsed but before it is + // validated. It is intended for tests. + MutateConfig func(*config.Config) +} + +// GenerateResult is the outcome of a Generate call. Files maps absolute output +// paths to file contents; callers are responsible for writing them to disk if +// desired. Errors collects any errors encountered during code generation. +type GenerateResult struct { + // Files maps absolute output paths to generated file contents. + Files map[string]string + + // Errors collects any errors encountered. A non-empty Errors slice means + // generation did not fully succeed. + Errors []error +} + +// Generate parses the sqlc configuration referenced by opts and runs every +// configured codegen target. The returned GenerateResult always has a non-nil +// Files map; the map is empty when generation fails before any files are +// produced. +func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { + stderr := opts.Stderr + if stderr == nil { + stderr = io.Discard + } + + res := GenerateResult{Files: map[string]string{}} + + configPath, conf, err := readConfig(stderr, opts.Dir, opts.File) + if err != nil { + res.Errors = append(res.Errors, err) + return res + } + if opts.MutateConfig != nil { + opts.MutateConfig(conf) + } + + base := filepath.Base(configPath) + if err := config.Validate(conf); err != nil { + fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) + res.Errors = append(res.Errors, err) + return res + } + + if opts.DisableProcessPlugins { + if err := validateProcessPluginsDisabled(conf); err != nil { + fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) + res.Errors = append(res.Errors, err) + return res + } + } + + g := &generator{ + dir: opts.Dir, + output: map[string]string{}, + } + + if err := processQuerySets(ctx, g, conf, opts.Dir, stderr); err != nil { + res.Errors = append(res.Errors, err) + return res + } + + res.Files = g.output + return res +} + +func validateProcessPluginsDisabled(cfg *config.Config) error { + for _, plugin := range cfg.Plugins { + if plugin.Process != nil { + return errPluginProcessDisabled + } + } + return nil +} + +type generator struct { + m sync.Mutex + dir string + output map[string]string +} + +func (g *generator) Pairs(ctx context.Context, conf *config.Config) []outputPair { + var pairs []outputPair + for _, sql := range conf.SQL { + if sql.Gen.Go != nil { + pairs = append(pairs, outputPair{ + SQL: sql, + Gen: config.SQLGen{Go: sql.Gen.Go}, + }) + } + if sql.Gen.JSON != nil { + pairs = append(pairs, outputPair{ + SQL: sql, + Gen: config.SQLGen{JSON: sql.Gen.JSON}, + }) + } + for i := range sql.Codegen { + pairs = append(pairs, outputPair{ + SQL: sql, + Plugin: &sql.Codegen[i], + }) + } + } + return pairs +} + +func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) error { + out, resp, err := codegen(ctx, combo, sql, result) + if err != nil { + return err + } + files := map[string]string{} + for _, file := range resp.Files { + files[file.Name] = string(file.Contents) + } + g.m.Lock() + + // out is specified by the user, not a plugin + absout := filepath.Join(g.dir, out) + + for n, source := range files { + filename := filepath.Join(g.dir, out, n) + // filepath.Join calls filepath.Clean which should remove all "..", but + // double check to make sure + if strings.Contains(filename, "..") { + return fmt.Errorf("invalid file output path: %s", filename) + } + // The output file must be contained inside the output directory + if !strings.HasPrefix(filename, absout) { + return fmt.Errorf("invalid file output path: %s", filename) + } + g.output[filename] = source + } + g.m.Unlock() + return nil +} diff --git a/internal/api/parse.go b/internal/api/parse.go new file mode 100644 index 0000000000..d2487eebea --- /dev/null +++ b/internal/api/parse.go @@ -0,0 +1,63 @@ +package api + +import ( + "context" + "fmt" + "io" + "path/filepath" + "runtime/trace" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/multierr" + "github.com/sqlc-dev/sqlc/internal/opts" +) + +func printFileErr(stderr io.Writer, dir string, fileErr *multierr.FileError) { + filename, err := filepath.Rel(dir, fileErr.Filename) + if err != nil { + filename = fileErr.Filename + } + fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err) +} + +func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { + defer trace.StartRegion(ctx, "parse").End() + c, err := compiler.NewCompiler(sql, combo, parserOpts) + defer func() { + if c != nil { + c.Close(ctx) + } + }() + if err != nil { + fmt.Fprintf(stderr, "error creating compiler: %s\n", err) + return nil, true + } + if err := c.ParseCatalog(sql.Schema); err != nil { + fmt.Fprintf(stderr, "# package %s\n", name) + if parserErr, ok := err.(*multierr.Error); ok { + for _, fileErr := range parserErr.Errs() { + printFileErr(stderr, dir, fileErr) + } + } else { + fmt.Fprintf(stderr, "error parsing schema: %s\n", err) + } + return nil, true + } + if parserOpts.Debug.DumpCatalog { + debug.Dump(c.Catalog()) + } + if err := c.ParseQueries(sql.Queries, parserOpts); err != nil { + fmt.Fprintf(stderr, "# package %s\n", name) + if parserErr, ok := err.(*multierr.Error); ok { + for _, fileErr := range parserErr.Errs() { + printFileErr(stderr, dir, fileErr) + } + } else { + fmt.Fprintf(stderr, "error parsing queries: %s\n", err) + } + return nil, true + } + return c.Result(), false +} diff --git a/internal/api/process.go b/internal/api/process.go new file mode 100644 index 0000000000..95d2c46e1e --- /dev/null +++ b/internal/api/process.go @@ -0,0 +1,109 @@ +package api + +import ( + "bytes" + "context" + "fmt" + "io" + "path/filepath" + "runtime" + "runtime/trace" + + "golang.org/x/sync/errgroup" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/opts" +) + +type outputPair struct { + Gen config.SQLGen + Plugin *config.Codegen + + config.SQL +} + +type resultProcessor interface { + Pairs(context.Context, *config.Config) []outputPair + ProcessResult(context.Context, config.CombinedSettings, outputPair, *compiler.Result) error +} + +func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Config, dir string, stderr io.Writer) error { + errored := false + + pairs := rp.Pairs(ctx, conf) + grp, gctx := errgroup.WithContext(ctx) + grp.SetLimit(runtime.GOMAXPROCS(0)) + + stderrs := make([]bytes.Buffer, len(pairs)) + + for i, pair := range pairs { + sql := pair + errout := &stderrs[i] + + grp.Go(func() error { + combo := config.Combine(*conf, sql.SQL) + if sql.Plugin != nil { + combo.Codegen = *sql.Plugin + } + + // TODO: This feels like a hack that will bite us later + joined := make([]string, 0, len(sql.Schema)) + for _, s := range sql.Schema { + joined = append(joined, filepath.Join(dir, s)) + } + sql.Schema = joined + + joined = make([]string, 0, len(sql.Queries)) + for _, q := range sql.Queries { + joined = append(joined, filepath.Join(dir, q)) + } + sql.Queries = joined + + var name, lang string + parseOpts := opts.Parser{ + Debug: debug.Debug, + } + + switch { + case sql.Gen.Go != nil: + name = combo.Go.Package + lang = "golang" + + case sql.Plugin != nil: + lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin) + name = sql.Plugin.Plugin + } + + packageRegion := trace.StartRegion(gctx, "package") + trace.Logf(gctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) + + result, failed := parse(gctx, name, dir, sql.SQL, combo, parseOpts, errout) + if failed { + packageRegion.End() + errored = true + return nil + } + if err := rp.ProcessResult(gctx, combo, sql, result); err != nil { + fmt.Fprintf(errout, "# package %s\n", name) + fmt.Fprintf(errout, "error generating code: %s\n", err) + errored = true + } + packageRegion.End() + return nil + }) + } + if err := grp.Wait(); err != nil { + return err + } + if errored { + for i := range stderrs { + if _, err := io.Copy(stderr, &stderrs[i]); err != nil { + return err + } + } + return fmt.Errorf("errored") + } + return nil +} diff --git a/internal/api/shim.go b/internal/api/shim.go new file mode 100644 index 0000000000..9638d1f126 --- /dev/null +++ b/internal/api/shim.go @@ -0,0 +1,233 @@ +package api + +import ( + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/config/convert" + "github.com/sqlc-dev/sqlc/internal/info" + "github.com/sqlc-dev/sqlc/internal/plugin" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings { + return &plugin.Settings{ + Version: cs.Global.Version, + Engine: string(cs.Package.Engine), + Schema: []string(cs.Package.Schema), + Queries: []string(cs.Package.Queries), + Codegen: pluginCodegen(cs, cs.Codegen), + } +} + +func pluginCodegen(cs config.CombinedSettings, s config.Codegen) *plugin.Codegen { + opts, err := convert.YAMLtoJSON(s.Options) + if err != nil { + panic(err) + } + cg := &plugin.Codegen{ + Out: s.Out, + Plugin: s.Plugin, + Options: opts, + } + for _, p := range cs.Global.Plugins { + if p.Name == s.Plugin { + cg.Env = p.Env + cg.Process = pluginProcess(p) + cg.Wasm = pluginWASM(p) + return cg + } + } + return cg +} + +func pluginProcess(p config.Plugin) *plugin.Codegen_Process { + if p.Process != nil { + return &plugin.Codegen_Process{ + Cmd: p.Process.Cmd, + } + } + return nil +} + +func pluginWASM(p config.Plugin) *plugin.Codegen_WASM { + if p.WASM != nil { + return &plugin.Codegen_WASM{ + Url: p.WASM.URL, + Sha256: p.WASM.SHA256, + } + } + return nil +} + +func pluginCatalog(c *catalog.Catalog) *plugin.Catalog { + var schemas []*plugin.Schema + for _, s := range c.Schemas { + var enums []*plugin.Enum + var cts []*plugin.CompositeType + for _, typ := range s.Types { + switch typ := typ.(type) { + case *catalog.Enum: + enums = append(enums, &plugin.Enum{ + Name: typ.Name, + Comment: typ.Comment, + Vals: typ.Vals, + }) + case *catalog.CompositeType: + cts = append(cts, &plugin.CompositeType{ + Name: typ.Name, + Comment: typ.Comment, + }) + } + } + var tables []*plugin.Table + for _, t := range s.Tables { + var columns []*plugin.Column + for _, c := range t.Columns { + l := -1 + if c.Length != nil { + l = *c.Length + } + columns = append(columns, &plugin.Column{ + Name: c.Name, + Type: &plugin.Identifier{ + Catalog: c.Type.Catalog, + Schema: c.Type.Schema, + Name: c.Type.Name, + }, + Comment: c.Comment, + NotNull: c.IsNotNull, + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + Table: &plugin.Identifier{ + Catalog: t.Rel.Catalog, + Schema: t.Rel.Schema, + Name: t.Rel.Name, + }, + }) + } + tables = append(tables, &plugin.Table{ + Rel: &plugin.Identifier{ + Catalog: t.Rel.Catalog, + Schema: t.Rel.Schema, + Name: t.Rel.Name, + }, + Columns: columns, + Comment: t.Comment, + }) + } + schemas = append(schemas, &plugin.Schema{ + Comment: s.Comment, + Name: s.Name, + Tables: tables, + Enums: enums, + CompositeTypes: cts, + }) + } + return &plugin.Catalog{ + Name: c.Name, + DefaultSchema: c.DefaultSchema, + Comment: c.Comment, + Schemas: schemas, + } +} + +func pluginQueries(r *compiler.Result) []*plugin.Query { + var out []*plugin.Query + for _, q := range r.Queries { + var params []*plugin.Parameter + var columns []*plugin.Column + for _, c := range q.Columns { + columns = append(columns, pluginQueryColumn(c)) + } + for _, p := range q.Params { + params = append(params, pluginQueryParam(p)) + } + var iit *plugin.Identifier + if q.InsertIntoTable != nil { + iit = &plugin.Identifier{ + Catalog: q.InsertIntoTable.Catalog, + Schema: q.InsertIntoTable.Schema, + Name: q.InsertIntoTable.Name, + } + } + out = append(out, &plugin.Query{ + Name: q.Metadata.Name, + Cmd: q.Metadata.Cmd, + Text: q.SQL, + Comments: q.Metadata.Comments, + Columns: columns, + Params: params, + Filename: q.Metadata.Filename, + InsertIntoTable: iit, + }) + } + return out +} + +func pluginQueryColumn(c *compiler.Column) *plugin.Column { + l := -1 + if c.Length != nil { + l = *c.Length + } + out := &plugin.Column{ + Name: c.Name, + OriginalName: c.OriginalName, + Comment: c.Comment, + NotNull: c.NotNull, + Unsigned: c.Unsigned, + IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + IsNamedParam: c.IsNamedParam, + IsFuncCall: c.IsFuncCall, + IsSqlcSlice: c.IsSqlcSlice, + } + + if c.Type != nil { + out.Type = &plugin.Identifier{ + Catalog: c.Type.Catalog, + Schema: c.Type.Schema, + Name: c.Type.Name, + } + } else { + out.Type = &plugin.Identifier{ + Name: c.DataType, + } + } + + if c.Table != nil { + out.Table = &plugin.Identifier{ + Catalog: c.Table.Catalog, + Schema: c.Table.Schema, + Name: c.Table.Name, + } + } + + if c.EmbedTable != nil { + out.EmbedTable = &plugin.Identifier{ + Catalog: c.EmbedTable.Catalog, + Schema: c.EmbedTable.Schema, + Name: c.EmbedTable.Name, + } + } + + return out +} + +func pluginQueryParam(p compiler.Parameter) *plugin.Parameter { + return &plugin.Parameter{ + Number: int32(p.Number), + Column: pluginQueryColumn(p.Column), + } +} + +func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.GenerateRequest { + return &plugin.GenerateRequest{ + Settings: pluginSettings(r, settings), + Catalog: pluginCatalog(r.Catalog), + Queries: pluginQueries(r), + SqlcVersion: info.Version, + } +} diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index ca3ee680b5..ccbc2c2169 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -2,30 +2,19 @@ package cmd import ( "context" - "encoding/json" "errors" "fmt" "io" "os" "path/filepath" "runtime/trace" - "strings" - "sync" - "google.golang.org/grpc" - - "github.com/sqlc-dev/sqlc/internal/codegen/golang" - genjson "github.com/sqlc-dev/sqlc/internal/codegen/json" + "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/compiler" "github.com/sqlc-dev/sqlc/internal/config" - "github.com/sqlc-dev/sqlc/internal/config/convert" "github.com/sqlc-dev/sqlc/internal/debug" - "github.com/sqlc-dev/sqlc/internal/ext" - "github.com/sqlc-dev/sqlc/internal/ext/process" - "github.com/sqlc-dev/sqlc/internal/ext/wasm" "github.com/sqlc-dev/sqlc/internal/multierr" "github.com/sqlc-dev/sqlc/internal/opts" - "github.com/sqlc-dev/sqlc/internal/plugin" ) const errMessageNoVersion = `The configuration file must have a version number. @@ -51,15 +40,6 @@ func printFileErr(stderr io.Writer, dir string, fileErr *multierr.FileError) { fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err) } -func findPlugin(conf config.Config, name string) (*config.Plugin, error) { - for _, plug := range conf.Plugins { - if plug.Name == name { - return &plug, nil - } - } - return nil, fmt.Errorf("plugin not found") -} - func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, error) { configPath := "" if filename != "" { @@ -127,98 +107,21 @@ func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, return configPath, &conf, nil } +// Generate is a thin wrapper around api.Generate that translates between the +// CLI's Options struct and api.GenerateOptions. New callers should prefer +// api.Generate directly. func Generate(ctx context.Context, dir, filename string, o *Options) (map[string]string, error) { - e := o.Env - stderr := o.Stderr - - configPath, conf, err := o.ReadConfig(dir, filename) - if err != nil { - return nil, err - } - - base := filepath.Base(configPath) - if err := config.Validate(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) - return nil, err - } - - if err := e.Validate(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) - return nil, err - } - - g := &generator{ - dir: dir, - output: map[string]string{}, - } - - if err := processQuerySets(ctx, g, conf, dir, o); err != nil { - return nil, err - } - - return g.output, nil -} - -type generator struct { - m sync.Mutex - dir string - output map[string]string -} - -func (g *generator) Pairs(ctx context.Context, conf *config.Config) []OutputPair { - var pairs []OutputPair - for _, sql := range conf.SQL { - if sql.Gen.Go != nil { - pairs = append(pairs, OutputPair{ - SQL: sql, - Gen: config.SQLGen{Go: sql.Gen.Go}, - }) - } - if sql.Gen.JSON != nil { - pairs = append(pairs, OutputPair{ - SQL: sql, - Gen: config.SQLGen{JSON: sql.Gen.JSON}, - }) - } - for i := range sql.Codegen { - pairs = append(pairs, OutputPair{ - SQL: sql, - Plugin: &sql.Codegen[i], - }) - } - } - return pairs -} - -func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result) error { - out, resp, err := codegen(ctx, combo, sql, result) - if err != nil { - return err - } - files := map[string]string{} - for _, file := range resp.Files { - files[file.Name] = string(file.Contents) - } - g.m.Lock() - - // out is specified by the user, not a plugin - absout := filepath.Join(g.dir, out) - - for n, source := range files { - filename := filepath.Join(g.dir, out, n) - // filepath.Join calls filepath.Clean which should remove all "..", but - // double check to make sure - if strings.Contains(filename, "..") { - return fmt.Errorf("invalid file output path: %s", filename) - } - // The output file must be contained inside the output directory - if !strings.HasPrefix(filename, absout) { - return fmt.Errorf("invalid file output path: %s", filename) - } - g.output[filename] = source - } - g.m.Unlock() - return nil + res := api.Generate(ctx, api.GenerateOptions{ + Dir: dir, + File: filename, + Stderr: o.Stderr, + DisableProcessPlugins: !o.Env.Debug.ProcessPlugins, + MutateConfig: o.MutateConfig, + }) + if len(res.Errors) > 0 { + return res.Files, res.Errors[0] + } + return res.Files, nil } func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { @@ -260,82 +163,3 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C } return c.Result(), false } - -func codegen(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) { - defer trace.StartRegion(ctx, "codegen").End() - req := codeGenRequest(result, combo) - var handler grpc.ClientConnInterface - var out string - switch { - case sql.Plugin != nil: - out = sql.Plugin.Out - plug, err := findPlugin(combo.Global, sql.Plugin.Plugin) - if err != nil { - return "", nil, fmt.Errorf("plugin not found: %s", err) - } - - switch { - case plug.Process != nil: - handler = &process.Runner{ - Cmd: plug.Process.Cmd, - Env: plug.Env, - Format: plug.Process.Format, - } - case plug.WASM != nil: - handler = &wasm.Runner{ - URL: plug.WASM.URL, - SHA256: plug.WASM.SHA256, - Env: plug.Env, - } - default: - return "", nil, fmt.Errorf("unsupported plugin type") - } - - opts, err := convert.YAMLtoJSON(sql.Plugin.Options) - if err != nil { - return "", nil, fmt.Errorf("invalid plugin options: %w", err) - } - req.PluginOptions = opts - - global, found := combo.Global.Options[plug.Name] - if found { - opts, err := convert.YAMLtoJSON(global) - if err != nil { - return "", nil, fmt.Errorf("invalid global options: %w", err) - } - req.GlobalOptions = opts - } - - case sql.Gen.Go != nil: - out = combo.Go.Out - handler = ext.HandleFunc(golang.Generate) - opts, err := json.Marshal(sql.Gen.Go) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.PluginOptions = opts - - if combo.Global.Overrides.Go != nil { - opts, err := json.Marshal(combo.Global.Overrides.Go) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.GlobalOptions = opts - } - - case sql.Gen.JSON != nil: - out = combo.JSON.Out - handler = ext.HandleFunc(genjson.Generate) - opts, err := json.Marshal(sql.Gen.JSON) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.PluginOptions = opts - - default: - return "", nil, fmt.Errorf("missing language backend") - } - client := plugin.NewCodegenServiceClient(handler) - resp, err := client.Generate(ctx, req) - return out, resp, err -} diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 91e44ff7f0..c17c2a4fda 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -14,6 +14,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/cmd" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/opts" @@ -58,15 +59,14 @@ func TestExamples(t *testing.T) { t.Parallel() path := filepath.Join(examples, tc) var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, + res := api.Generate(ctx, api.GenerateOptions{ + Dir: path, Stderr: &stderr, - } - output, err := cmd.Generate(ctx, path, "", opts) - if err != nil { + }) + if len(res.Errors) > 0 { t.Fatalf("sqlc generate failed: %s", stderr.String()) } - cmpDirectory(t, path, output) + cmpDirectory(t, path, res.Files) }) } } @@ -90,11 +90,10 @@ func BenchmarkExamples(b *testing.B) { path := filepath.Join(examples, tc) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, + api.Generate(ctx, api.GenerateOptions{ + Dir: path, Stderr: &stderr, - } - cmd.Generate(ctx, path, "", opts) + }) } }) } @@ -261,9 +260,10 @@ func TestReplay(t *testing.T) { } } - opts := cmd.Options{ + dbg := opts.DebugFromString(args.Env["SQLCDEBUG"]) + cmdOpts := cmd.Options{ Env: cmd.Env{ - Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]), + Debug: dbg, Experiment: opts.ExperimentFromString(args.Env["SQLCEXPERIMENT"]), }, Stderr: &stderr, @@ -272,14 +272,23 @@ func TestReplay(t *testing.T) { switch args.Command { case "diff": - err = cmd.Diff(ctx, path, "", &opts) + err = cmd.Diff(ctx, path, "", &cmdOpts) case "generate": - output, err = cmd.Generate(ctx, path, "", &opts) + res := api.Generate(ctx, api.GenerateOptions{ + Dir: path, + Stderr: &stderr, + DisableProcessPlugins: !dbg.ProcessPlugins, + MutateConfig: testctx.Mutate(t, path), + }) + output = res.Files + if len(res.Errors) > 0 { + err = res.Errors[0] + } if err == nil { cmpDirectory(t, path, output) } case "vet": - err = cmd.Vet(ctx, path, "", &opts) + err = cmd.Vet(ctx, path, "", &cmdOpts) default: t.Fatalf("unknown command") } @@ -387,11 +396,10 @@ func BenchmarkReplay(b *testing.B) { path, _ := filepath.Abs(tc) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, + api.Generate(ctx, api.GenerateOptions{ + Dir: path, Stderr: &stderr, - } - cmd.Generate(ctx, path, "", opts) + }) } }) } From 9848cd6b9fae61f68ac49d121159b31cc065d418 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 28 Apr 2026 06:23:36 +0000 Subject: [PATCH 2/4] Add api.GenerateOptions.Write and api.GenerateOptions.Diff MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The two new boolean options let api.Generate cover the writefiles loop and the diff comparison that previously lived in cmd. The compile command becomes Generate with neither flag set, generate maps to Write: true, and diff maps to Diff: true. While simplifying GenerateOptions: * Drop MutateConfig — tests now express config mutations by writing a temporary configuration file via writeMutatedConfig and pointing GenerateOptions.File at it. The mutated config is parsed (always to v2 shape), forced to version "2", and round-tripped via yaml. * Drop DisableProcessPlugins from the API surface; we will revisit how to express that constraint. * Add MarshalJSON/MarshalYAML to AnalyzerDatabase so the parsed Config round-trips through yaml.Marshal cleanly, which is what the new test helper relies on. cmd/diff.go is gone and cmd/generate.go is left with only the helpers (readConfig, parse, printFileErr) other cmd commands still use. https://claude.ai/code/session_01RCzB2JR5Y5ScFDUmwcxGVZ --- internal/api/diff.go | 100 ++++++++++++++++++++++ internal/api/generate.go | 48 ++++------- internal/cmd/cmd.go | 66 ++++---------- internal/cmd/diff.go | 59 ------------- internal/cmd/generate.go | 18 ---- internal/cmd/options.go | 12 +-- internal/config/config.go | 24 ++++++ internal/endtoend/endtoend_test.go | 133 ++++++++++++++++++++--------- 8 files changed, 252 insertions(+), 208 deletions(-) create mode 100644 internal/api/diff.go delete mode 100644 internal/cmd/diff.go diff --git a/internal/api/diff.go b/internal/api/diff.go new file mode 100644 index 0000000000..9c07fd1456 --- /dev/null +++ b/internal/api/diff.go @@ -0,0 +1,100 @@ +package api + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "runtime/trace" + "sort" + "strings" + + "github.com/cubicdaiya/gonp" +) + +func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer) error { + defer trace.StartRegion(ctx, "writefiles").End() + for filename, source := range files { + if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil { + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + return err + } + if err := os.WriteFile(filename, []byte(source), 0644); err != nil { + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + return err + } + } + return nil +} + +func diffFiles(ctx context.Context, dir string, files map[string]string, stderr io.Writer) error { + defer trace.StartRegion(ctx, "checkfiles").End() + var errored bool + + keys := make([]string, 0, len(files)) + for k := range files { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, filename := range keys { + source := files[filename] + if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) { + errored = true + continue + } + existing, err := os.ReadFile(filename) + if err != nil { + errored = true + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + continue + } + d := gonp.New(getLines(existing), getLines([]byte(source))) + d.Compose() + uniHunks := filterHunks(d.UnifiedHunks()) + + if len(uniHunks) > 0 { + errored = true + fmt.Fprintf(stderr, "--- a%s\n", strings.TrimPrefix(filename, dir)) + fmt.Fprintf(stderr, "+++ b%s\n", strings.TrimPrefix(filename, dir)) + d.FprintUniHunks(stderr, uniHunks) + } + } + if errored { + return errors.New("diff found") + } + return nil +} + +func getLines(f []byte) []string { + fp := bytes.NewReader(f) + scanner := bufio.NewScanner(fp) + lines := make([]string, 0) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + return lines +} + +func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] { + var out []gonp.UniHunk[T] + for i, uniHunk := range uniHunks { + var changed bool + for _, e := range uniHunk.GetChanges() { + switch e.GetType() { + case gonp.SesDelete: + changed = true + case gonp.SesAdd: + changed = true + } + } + if changed { + out = append(out, uniHunks[i]) + } + } + return out +} diff --git a/internal/api/generate.go b/internal/api/generate.go index 5d10f015d1..7321e08445 100644 --- a/internal/api/generate.go +++ b/internal/api/generate.go @@ -2,7 +2,6 @@ package api import ( "context" - "errors" "fmt" "io" "path/filepath" @@ -13,10 +12,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" ) -// errPluginProcessDisabled is returned when the configuration uses a process -// plugin but the caller has disabled them via GenerateOptions.DisableProcessPlugins. -var errPluginProcessDisabled = errors.New("plugin: process-based plugins disabled via SQLCDEBUG=processplugins=0") - // GenerateOptions controls a single Generate invocation. type GenerateOptions struct { // Dir is the working directory used to resolve the config file and any @@ -30,14 +25,14 @@ type GenerateOptions struct { // Stderr receives diagnostic output. If nil, output is discarded. Stderr io.Writer - // DisableProcessPlugins, when true, causes Generate to fail if the - // configuration uses a process-based plugin. The sqlc CLI sets this from - // SQLCDEBUG=processplugins=0. - DisableProcessPlugins bool + // Write, when true, writes the generated files to disk after a successful + // generate. Failures are reported via GenerateResult.Errors. + Write bool - // MutateConfig is called after the configuration is parsed but before it is - // validated. It is intended for tests. - MutateConfig func(*config.Config) + // Diff, when true, compares each generated file against any existing file + // on disk and writes a unified diff for differences to Stderr. If any + // differences are found, an error is appended to GenerateResult.Errors. + Diff bool } // GenerateResult is the outcome of a Generate call. Files maps absolute output @@ -69,9 +64,6 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { res.Errors = append(res.Errors, err) return res } - if opts.MutateConfig != nil { - opts.MutateConfig(conf) - } base := filepath.Base(configPath) if err := config.Validate(conf); err != nil { @@ -80,14 +72,6 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { return res } - if opts.DisableProcessPlugins { - if err := validateProcessPluginsDisabled(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) - res.Errors = append(res.Errors, err) - return res - } - } - g := &generator{ dir: opts.Dir, output: map[string]string{}, @@ -99,16 +83,20 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { } res.Files = g.output - return res -} -func validateProcessPluginsDisabled(cfg *config.Config) error { - for _, plugin := range cfg.Plugins { - if plugin.Process != nil { - return errPluginProcessDisabled + if opts.Write { + if err := writeFiles(ctx, res.Files, stderr); err != nil { + res.Errors = append(res.Errors, err) } } - return nil + + if opts.Diff { + if err := diffFiles(ctx, opts.Dir, res.Files, stderr); err != nil { + res.Errors = append(res.Errors, err) + } + } + + return res } type generator struct { diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index f9c09dfe06..6e5ff72bc7 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -1,8 +1,6 @@ package cmd import ( - "bufio" - "bytes" "context" "errors" "fmt" @@ -12,11 +10,11 @@ import ( "path/filepath" "runtime/trace" - "github.com/cubicdaiya/gonp" "github.com/spf13/cobra" "github.com/spf13/pflag" "gopkg.in/yaml.v3" + "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" "github.com/sqlc-dev/sqlc/internal/info" @@ -191,21 +189,15 @@ var genCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "generate").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - output, err := Generate(cmd.Context(), dir, name, &Options{ - Env: ParseEnv(cmd), + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Dir: dir, + File: name, Stderr: stderr, + Write: true, }) - if err != nil { + if len(res.Errors) > 0 { os.Exit(1) } - defer trace.StartRegion(cmd.Context(), "writefiles").End() - for filename, source := range output { - os.MkdirAll(filepath.Dir(filename), 0755) - if err := os.WriteFile(filename, []byte(source), 0644); err != nil { - fmt.Fprintf(stderr, "%s: %s\n", filename, err) - return err - } - } return nil }, } @@ -217,46 +209,18 @@ var checkCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "compile").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - _, err := Generate(cmd.Context(), dir, name, &Options{ - Env: ParseEnv(cmd), + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Dir: dir, + File: name, Stderr: stderr, }) - if err != nil { + if len(res.Errors) > 0 { os.Exit(1) } return nil }, } -func getLines(f []byte) []string { - fp := bytes.NewReader(f) - scanner := bufio.NewScanner(fp) - lines := make([]string, 0) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - return lines -} - -func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] { - var out []gonp.UniHunk[T] - for i, uniHunk := range uniHunks { - var changed bool - for _, e := range uniHunk.GetChanges() { - switch e.GetType() { - case gonp.SesDelete: - changed = true - case gonp.SesAdd: - changed = true - } - } - if changed { - out = append(out, uniHunks[i]) - } - } - return out -} - var diffCmd = &cobra.Command{ Use: "diff", Short: "Compare the generated files to the existing files", @@ -264,11 +228,13 @@ var diffCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "diff").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - opts := &Options{ - Env: ParseEnv(cmd), + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Dir: dir, + File: name, Stderr: stderr, - } - if err := Diff(cmd.Context(), dir, name, opts); err != nil { + Diff: true, + }) + if len(res.Errors) > 0 { os.Exit(1) } return nil diff --git a/internal/cmd/diff.go b/internal/cmd/diff.go deleted file mode 100644 index 8998971a37..0000000000 --- a/internal/cmd/diff.go +++ /dev/null @@ -1,59 +0,0 @@ -package cmd - -import ( - "context" - "errors" - "fmt" - "os" - "runtime/trace" - "sort" - "strings" - - "github.com/cubicdaiya/gonp" -) - -func Diff(ctx context.Context, dir, name string, opts *Options) error { - stderr := opts.Stderr - output, err := Generate(ctx, dir, name, opts) - if err != nil { - return err - } - defer trace.StartRegion(ctx, "checkfiles").End() - var errored bool - - keys := make([]string, 0, len(output)) - for k, _ := range output { - kk := k - keys = append(keys, kk) - } - sort.Strings(keys) - - for _, filename := range keys { - source := output[filename] - if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) { - errored = true - // stdout message - continue - } - existing, err := os.ReadFile(filename) - if err != nil { - errored = true - fmt.Fprintf(stderr, "%s: %s\n", filename, err) - continue - } - diff := gonp.New(getLines(existing), getLines([]byte(source))) - diff.Compose() - uniHunks := filterHunks(diff.UnifiedHunks()) - - if len(uniHunks) > 0 { - errored = true - fmt.Fprintf(stderr, "--- a%s\n", strings.TrimPrefix(filename, dir)) - fmt.Fprintf(stderr, "+++ b%s\n", strings.TrimPrefix(filename, dir)) - diff.FprintUniHunks(stderr, uniHunks) - } - } - if errored { - return errors.New("diff found") - } - return nil -} diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index ccbc2c2169..1785a40718 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -9,7 +9,6 @@ import ( "path/filepath" "runtime/trace" - "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/compiler" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" @@ -107,23 +106,6 @@ func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, return configPath, &conf, nil } -// Generate is a thin wrapper around api.Generate that translates between the -// CLI's Options struct and api.GenerateOptions. New callers should prefer -// api.Generate directly. -func Generate(ctx context.Context, dir, filename string, o *Options) (map[string]string, error) { - res := api.Generate(ctx, api.GenerateOptions{ - Dir: dir, - File: filename, - Stderr: o.Stderr, - DisableProcessPlugins: !o.Env.Debug.ProcessPlugins, - MutateConfig: o.MutateConfig, - }) - if len(res.Errors) > 0 { - return res.Files, res.Errors[0] - } - return res.Files, nil -} - func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { defer trace.StartRegion(ctx, "parse").End() c, err := compiler.NewCompiler(sql, combo, parserOpts) diff --git a/internal/cmd/options.go b/internal/cmd/options.go index 02d3614f4e..7a1fab33f0 100644 --- a/internal/cmd/options.go +++ b/internal/cmd/options.go @@ -12,18 +12,8 @@ type Options struct { // TODO: Move these to a command-specific struct Tags []string Against string - - // Testing only - MutateConfig func(*config.Config) } func (o *Options) ReadConfig(dir, filename string) (string, *config.Config, error) { - path, conf, err := readConfig(o.Stderr, dir, filename) - if err != nil { - return path, conf, err - } - if o.MutateConfig != nil { - o.MutateConfig(conf) - } - return path, conf, nil + return readConfig(o.Stderr, dir, filename) } diff --git a/internal/config/config.go b/internal/config/config.go index d3e610ef05..19bcda754a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -189,6 +189,30 @@ func (a *AnalyzerDatabase) UnmarshalYAML(unmarshal func(interface{}) error) erro return errors.New("analyzer.database must be true, false, or \"only\"") } +// MarshalJSON serialises the value back to its source form so configs +// produced via UnmarshalJSON round-trip through json.Marshal. +func (a AnalyzerDatabase) MarshalJSON() ([]byte, error) { + if a.isOnly { + return json.Marshal("only") + } + if a.value == nil { + return json.Marshal(true) + } + return json.Marshal(*a.value) +} + +// MarshalYAML serialises the value back to its source form so configs +// produced via UnmarshalYAML round-trip through yaml.Marshal. +func (a AnalyzerDatabase) MarshalYAML() (interface{}, error) { + if a.isOnly { + return "only", nil + } + if a.value == nil { + return true, nil + } + return *a.value, nil +} + type Analyzer struct { Database AnalyzerDatabase `json:"database" yaml:"database"` } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index c17c2a4fda..68cc037225 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "context" + "fmt" "os" osexec "os/exec" "path/filepath" @@ -13,6 +14,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "gopkg.in/yaml.v3" "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/cmd" @@ -99,11 +101,73 @@ func BenchmarkExamples(b *testing.B) { } } +// textContext describes a TestReplay scenario. Mutate returns the config +// filename (relative to the test directory) that should be passed to the +// command under test. The "base" context returns "" to use the project's +// existing sqlc config; the "managed-db" context writes a mutated copy of the +// config to a temporary file inside the test directory and returns its name. type textContext struct { - Mutate func(*testing.T, string) func(*config.Config) + Mutate func(*testing.T, string) string Enabled func() bool } +// writeMutatedConfig parses the sqlc config in dir, applies mutate to the +// in-memory Config (which is always v2-shaped, even when the file on disk is +// v1), forces version "2", and writes the result to a temp file alongside the +// original. The temp file is removed when the test ends. +func writeMutatedConfig(t *testing.T, dir string, mutate func(*config.Config)) string { + t.Helper() + original, conf, err := readSqlcConfig(dir) + if err != nil { + t.Fatalf("read sqlc config from %s: %s", dir, err) + } + + // Parsing v1 configs converts them to a v2-shaped Config. Force version "2" + // so the mutated config can be re-parsed as v2 from disk. + conf.Version = "2" + mutate(conf) + + f, err := os.CreateTemp(dir, "sqlc.test-*"+filepath.Ext(original)) + if err != nil { + t.Fatalf("create temp config in %s: %s", dir, err) + } + t.Cleanup(func() { os.Remove(f.Name()) }) + + enc := yaml.NewEncoder(f) + if err := enc.Encode(conf); err != nil { + f.Close() + t.Fatalf("write temp config %s: %s", f.Name(), err) + } + if err := enc.Close(); err != nil { + f.Close() + t.Fatalf("close yaml encoder for %s: %s", f.Name(), err) + } + if err := f.Close(); err != nil { + t.Fatalf("close temp config %s: %s", f.Name(), err) + } + return filepath.Base(f.Name()) +} + +func readSqlcConfig(dir string) (string, *config.Config, error) { + for _, name := range []string{"sqlc.yaml", "sqlc.yml", "sqlc.json"} { + path := filepath.Join(dir, name) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + continue + } + return path, nil, err + } + defer f.Close() + conf, err := config.ParseConfig(f) + if err != nil { + return path, nil, fmt.Errorf("parse %s: %w", path, err) + } + return path, &conf, nil + } + return "", nil, fmt.Errorf("no sqlc config found in %s", dir) +} + func TestReplay(t *testing.T) { // Ensure that this environment variable is always set to true when running // end-to-end tests @@ -172,45 +236,24 @@ func TestReplay(t *testing.T) { contexts := map[string]textContext{ "base": { - Mutate: func(t *testing.T, path string) func(*config.Config) { return func(c *config.Config) {} }, + Mutate: func(t *testing.T, path string) string { return "" }, Enabled: func() bool { return true }, }, "managed-db": { - Mutate: func(t *testing.T, path string) func(*config.Config) { - return func(c *config.Config) { + Mutate: func(t *testing.T, path string) string { + return writeMutatedConfig(t, path, func(c *config.Config) { // Add all servers - tests will fail if database isn't available c.Servers = []config.Server{ - { - Name: "postgres", - Engine: config.EnginePostgreSQL, - URI: postgresURI, - }, - { - Name: "mysql", - Engine: config.EngineMySQL, - URI: mysqlURI, - }, + {Name: "postgres", Engine: config.EnginePostgreSQL, URI: postgresURI}, + {Name: "mysql", Engine: config.EngineMySQL, URI: mysqlURI}, } - for i := range c.SQL { switch c.SQL[i].Engine { - case config.EnginePostgreSQL: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - case config.EngineMySQL: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - case config.EngineSQLite: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - default: - // pass + case config.EnginePostgreSQL, config.EngineMySQL, config.EngineSQLite: + c.SQL[i].Database = &config.Database{Managed: true} } } - } + }) }, Enabled: func() bool { // Enabled if at least one database URI is available @@ -260,25 +303,31 @@ func TestReplay(t *testing.T) { } } - dbg := opts.DebugFromString(args.Env["SQLCDEBUG"]) + configFile := testctx.Mutate(t, path) cmdOpts := cmd.Options{ Env: cmd.Env{ - Debug: dbg, + Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]), Experiment: opts.ExperimentFromString(args.Env["SQLCEXPERIMENT"]), }, - Stderr: &stderr, - MutateConfig: testctx.Mutate(t, path), + Stderr: &stderr, } switch args.Command { case "diff": - err = cmd.Diff(ctx, path, "", &cmdOpts) + res := api.Generate(ctx, api.GenerateOptions{ + Dir: path, + File: configFile, + Stderr: &stderr, + Diff: true, + }) + if len(res.Errors) > 0 { + err = res.Errors[0] + } case "generate": res := api.Generate(ctx, api.GenerateOptions{ - Dir: path, - Stderr: &stderr, - DisableProcessPlugins: !dbg.ProcessPlugins, - MutateConfig: testctx.Mutate(t, path), + Dir: path, + File: configFile, + Stderr: &stderr, }) output = res.Files if len(res.Errors) > 0 { @@ -288,7 +337,7 @@ func TestReplay(t *testing.T) { cmpDirectory(t, path, output) } case "vet": - err = cmd.Vet(ctx, path, "", &cmdOpts) + err = cmd.Vet(ctx, path, configFile, &cmdOpts) default: t.Fatalf("unknown command") } @@ -332,6 +381,10 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { if filepath.Base(path) == "exec.json" { return nil } + // Mutated configs written by writeMutatedConfig. + if strings.HasPrefix(filepath.Base(path), "sqlc.test-") { + return nil + } if strings.Contains(path, "/kotlin/build") { return nil } From 963f6265b2e6ec0dd460a7734f6a91ae81228b5a Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 28 Apr 2026 12:40:12 +0000 Subject: [PATCH 3/4] Gate process plugins via api.GenerateOptions.InsecureProcessPluginNames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add an explicit allowlist of process-based plugin names to api.GenerateOptions. Generate fails before any parse or codegen runs if the configuration declares a process plugin whose name is not in the list. The "Insecure" prefix mirrors crypto/tls.Config.InsecureSkipVerify to flag the trust decision callers are making — process plugins execute arbitrary local commands. The CLI populates the allowlist by scanning the user's own config for declared process plugins, so `sqlc generate`, `sqlc compile`, and `sqlc diff` keep working. SQLCDEBUG=processplugins=0 still disables process plugins by leaving the allowlist nil. https://claude.ai/code/session_01RCzB2JR5Y5ScFDUmwcxGVZ --- internal/api/generate.go | 22 ++++++++++++++++++++ internal/cmd/cmd.go | 43 ++++++++++++++++++++++++++++++---------- internal/cmd/generate.go | 19 ++++++++++++++++++ 3 files changed, 73 insertions(+), 11 deletions(-) diff --git a/internal/api/generate.go b/internal/api/generate.go index 7321e08445..04da166bba 100644 --- a/internal/api/generate.go +++ b/internal/api/generate.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "path/filepath" + "slices" "strings" "sync" @@ -33,6 +34,15 @@ type GenerateOptions struct { // on disk and writes a unified diff for differences to Stderr. If any // differences are found, an error is appended to GenerateResult.Errors. Diff bool + + // InsecureProcessPluginNames is the allowlist of process-based plugin + // names that Generate is permitted to invoke. Any process plugin declared + // in the configuration whose name is not in this list causes Generate to + // fail before parsing or codegen runs. Process plugins execute arbitrary + // local commands; the "Insecure" prefix mirrors + // crypto/tls.Config.InsecureSkipVerify as a reminder that callers must + // consciously trust each plugin name they pass here. + InsecureProcessPluginNames []string } // GenerateResult is the outcome of a Generate call. Files maps absolute output @@ -72,6 +82,18 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { return res } + for _, plug := range conf.Plugins { + if plug.Process == nil { + continue + } + if !slices.Contains(opts.InsecureProcessPluginNames, plug.Name) { + err := fmt.Errorf("process plugin %q is not in InsecureProcessPluginNames; refusing to run", plug.Name) + fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) + res.Errors = append(res.Errors, err) + return res + } + } + g := &generator{ dir: opts.Dir, output: map[string]string{}, diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 6e5ff72bc7..7519559c13 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -182,6 +182,21 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) { } } +// allowedProcessPluginNames returns the set of process plugin names the CLI +// trusts to run. SQLCDEBUG=processplugins=0 disables every process plugin by +// returning nil; otherwise we trust whatever the user declared in their own +// config. +func allowedProcessPluginNames(env Env, stderr io.Writer, dir, name string) []string { + if !env.Debug.ProcessPlugins { + return nil + } + names, err := processPluginNames(stderr, dir, name) + if err != nil { + os.Exit(1) + } + return names +} + var genCmd = &cobra.Command{ Use: "generate", Short: "Generate source code from SQL", @@ -189,11 +204,13 @@ var genCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "generate").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) + env := ParseEnv(cmd) res := api.Generate(cmd.Context(), api.GenerateOptions{ - Dir: dir, - File: name, - Stderr: stderr, - Write: true, + Dir: dir, + File: name, + Stderr: stderr, + Write: true, + InsecureProcessPluginNames: allowedProcessPluginNames(env, stderr, dir, name), }) if len(res.Errors) > 0 { os.Exit(1) @@ -209,10 +226,12 @@ var checkCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "compile").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) + env := ParseEnv(cmd) res := api.Generate(cmd.Context(), api.GenerateOptions{ - Dir: dir, - File: name, - Stderr: stderr, + Dir: dir, + File: name, + Stderr: stderr, + InsecureProcessPluginNames: allowedProcessPluginNames(env, stderr, dir, name), }) if len(res.Errors) > 0 { os.Exit(1) @@ -228,11 +247,13 @@ var diffCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "diff").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) + env := ParseEnv(cmd) res := api.Generate(cmd.Context(), api.GenerateOptions{ - Dir: dir, - File: name, - Stderr: stderr, - Diff: true, + Dir: dir, + File: name, + Stderr: stderr, + Diff: true, + InsecureProcessPluginNames: allowedProcessPluginNames(env, stderr, dir, name), }) if len(res.Errors) > 0 { os.Exit(1) diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 1785a40718..e45d193543 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -106,6 +106,25 @@ func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, return configPath, &conf, nil } +// processPluginNames returns the names of every process-based plugin declared +// in the sqlc configuration at dir/filename. The CLI passes the result to +// api.GenerateOptions.InsecureProcessPluginNames so commands run by the user +// (who wrote the config) can invoke any plugin they declared, while library +// callers are still required to opt in explicitly. +func processPluginNames(stderr io.Writer, dir, filename string) ([]string, error) { + _, conf, err := readConfig(stderr, dir, filename) + if err != nil { + return nil, err + } + var names []string + for _, p := range conf.Plugins { + if p.Process != nil { + names = append(names, p.Name) + } + } + return names, nil +} + func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { defer trace.StartRegion(ctx, "parse").End() c, err := compiler.NewCompiler(sql, combo, parserOpts) From 6b2d30d96b91d5cbc098f7efbfedb0a5799cfa43 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 28 Apr 2026 15:07:13 +0000 Subject: [PATCH 4/4] Replace api.GenerateOptions Dir/File with Config io.Reader MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The struct collapses to five fields: Config (io.Reader), Stderr, Write, Diff, InsecureProcessPluginNames. api.Generate parses the config from the reader and treats every relative path in it as relative to the current working directory. CLI: each command opens the config file, reads its bytes, parses it once to extract declared process-plugin names, then chdirs to the config's directory before invoking api.Generate. Single-process so chdir is fine. Tests: a new mutatedConfigBytes helper parses the test's sqlc.yaml, forces version "2", rewrites every schema/queries/output path to be absolute relative to the test directory, and re-encodes as YAML — so api.Generate works without knowing the source directory. Optional mutate callback applies extra changes (managed-db servers etc.) and also drops a temp file alongside the original for cmd.Vet which still takes a config path. cmd/process.go and cmd/vet.go now skip joining their dir parameter when the config-supplied path is already absolute. KNOWN ISSUE: TestReplay parse-error tests and the diff_output tests fail because the api now emits absolute paths in error messages and unified-diff labels (no config-dir context to strip). Either add a BaseDir hint back to GenerateOptions or update the affected test expectations to match. https://claude.ai/code/session_01RCzB2JR5Y5ScFDUmwcxGVZ --- internal/api/config.go | 93 -------------------- internal/api/diff.go | 15 +++- internal/api/generate.go | 99 +++++++++++++-------- internal/api/parse.go | 19 ++-- internal/api/process.go | 31 ++++--- internal/cmd/cmd.go | 70 +++++++++++---- internal/cmd/generate.go | 19 ---- internal/cmd/process.go | 10 ++- internal/cmd/vet.go | 12 ++- internal/endtoend/endtoend_test.go | 134 +++++++++++++++++++++-------- 10 files changed, 272 insertions(+), 230 deletions(-) delete mode 100644 internal/api/config.go diff --git a/internal/api/config.go b/internal/api/config.go deleted file mode 100644 index 00eef38d19..0000000000 --- a/internal/api/config.go +++ /dev/null @@ -1,93 +0,0 @@ -package api - -import ( - "errors" - "fmt" - "io" - "os" - "path/filepath" - - "github.com/sqlc-dev/sqlc/internal/config" -) - -const errMessageNoVersion = `The configuration file must have a version number. -Set the version to 1 or 2 at the top of sqlc.json: - -{ - "version": "1" - ... -} -` - -const errMessageUnknownVersion = `The configuration file has an invalid version number. -The supported version can only be "1" or "2". -` - -const errMessageNoPackages = `No packages are configured` - -func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, error) { - configPath := "" - if filename != "" { - configPath = filepath.Join(dir, filename) - } else { - var yamlMissing, jsonMissing, ymlMissing bool - yamlPath := filepath.Join(dir, "sqlc.yaml") - ymlPath := filepath.Join(dir, "sqlc.yml") - jsonPath := filepath.Join(dir, "sqlc.json") - - if _, err := os.Stat(yamlPath); os.IsNotExist(err) { - yamlMissing = true - } - if _, err := os.Stat(jsonPath); os.IsNotExist(err) { - jsonMissing = true - } - - if _, err := os.Stat(ymlPath); os.IsNotExist(err) { - ymlMissing = true - } - - if yamlMissing && ymlMissing && jsonMissing { - fmt.Fprintln(stderr, "error parsing configuration files. sqlc.(yaml|yml) or sqlc.json: file does not exist") - return "", nil, errors.New("config file missing") - } - - if (!yamlMissing || !ymlMissing) && !jsonMissing { - fmt.Fprintln(stderr, "error: both sqlc.json and sqlc.(yaml|yml) files present") - return "", nil, errors.New("sqlc.json and sqlc.(yaml|yml) present") - } - - if jsonMissing { - if yamlMissing { - configPath = ymlPath - } else { - configPath = yamlPath - } - } else { - configPath = jsonPath - } - } - - base := filepath.Base(configPath) - file, err := os.Open(configPath) - if err != nil { - fmt.Fprintf(stderr, "error parsing %s: file does not exist\n", base) - return "", nil, err - } - defer file.Close() - - conf, err := config.ParseConfig(file) - if err != nil { - switch err { - case config.ErrMissingVersion: - fmt.Fprint(stderr, errMessageNoVersion) - case config.ErrUnknownVersion: - fmt.Fprint(stderr, errMessageUnknownVersion) - case config.ErrNoPackages: - fmt.Fprint(stderr, errMessageNoPackages) - } - fmt.Fprintf(stderr, "error parsing %s: %s\n", base, err) - return "", nil, err - } - - return configPath, &conf, nil -} diff --git a/internal/api/diff.go b/internal/api/diff.go index 9c07fd1456..5e70797fc9 100644 --- a/internal/api/diff.go +++ b/internal/api/diff.go @@ -11,7 +11,6 @@ import ( "path/filepath" "runtime/trace" "sort" - "strings" "github.com/cubicdaiya/gonp" ) @@ -31,10 +30,12 @@ func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer) return nil } -func diffFiles(ctx context.Context, dir string, files map[string]string, stderr io.Writer) error { +func diffFiles(ctx context.Context, files map[string]string, stderr io.Writer) error { defer trace.StartRegion(ctx, "checkfiles").End() var errored bool + wd, _ := os.Getwd() + keys := make([]string, 0, len(files)) for k := range files { keys = append(keys, k) @@ -59,8 +60,14 @@ func diffFiles(ctx context.Context, dir string, files map[string]string, stderr if len(uniHunks) > 0 { errored = true - fmt.Fprintf(stderr, "--- a%s\n", strings.TrimPrefix(filename, dir)) - fmt.Fprintf(stderr, "+++ b%s\n", strings.TrimPrefix(filename, dir)) + label := filename + if wd != "" { + if rel, err := filepath.Rel(wd, filename); err == nil { + label = "/" + rel + } + } + fmt.Fprintf(stderr, "--- a%s\n", label) + fmt.Fprintf(stderr, "+++ b%s\n", label) d.FprintUniHunks(stderr, uniHunks) } } diff --git a/internal/api/generate.go b/internal/api/generate.go index 04da166bba..8c7bd6efce 100644 --- a/internal/api/generate.go +++ b/internal/api/generate.go @@ -2,6 +2,7 @@ package api import ( "context" + "errors" "fmt" "io" "path/filepath" @@ -13,26 +14,24 @@ import ( "github.com/sqlc-dev/sqlc/internal/config" ) -// GenerateOptions controls a single Generate invocation. +// GenerateOptions controls a single Generate invocation. Paths declared in the +// configuration are resolved relative to the current working directory, so +// callers wanting a different base directory should either pass absolute +// paths in the config or os.Chdir before calling. type GenerateOptions struct { - // Dir is the working directory used to resolve the config file and any - // relative schema/query paths within it. - Dir string - - // File is the configuration filename to use, relative to Dir. When empty, - // Generate looks for sqlc.yaml, sqlc.yml, or sqlc.json in Dir. - File string + // Config is the sqlc configuration as a YAML or JSON document. Required. + Config io.Reader // Stderr receives diagnostic output. If nil, output is discarded. Stderr io.Writer - // Write, when true, writes the generated files to disk after a successful - // generate. Failures are reported via GenerateResult.Errors. + // Write writes the generated files to disk after a successful generate. + // Failures are reported via GenerateResult.Errors. Write bool - // Diff, when true, compares each generated file against any existing file - // on disk and writes a unified diff for differences to Stderr. If any - // differences are found, an error is appended to GenerateResult.Errors. + // Diff compares each generated file against any existing file on disk and + // writes a unified diff for differences to Stderr. If any differences are + // found, an error is appended to GenerateResult.Errors. Diff bool // InsecureProcessPluginNames is the allowlist of process-based plugin @@ -45,9 +44,7 @@ type GenerateOptions struct { InsecureProcessPluginNames []string } -// GenerateResult is the outcome of a Generate call. Files maps absolute output -// paths to file contents; callers are responsible for writing them to disk if -// desired. Errors collects any errors encountered during code generation. +// GenerateResult is the outcome of a Generate call. type GenerateResult struct { // Files maps absolute output paths to generated file contents. Files map[string]string @@ -58,9 +55,7 @@ type GenerateResult struct { } // Generate parses the sqlc configuration referenced by opts and runs every -// configured codegen target. The returned GenerateResult always has a non-nil -// Files map; the map is empty when generation fails before any files are -// produced. +// configured codegen target. func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { stderr := opts.Stderr if stderr == nil { @@ -69,15 +64,30 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { res := GenerateResult{Files: map[string]string{}} - configPath, conf, err := readConfig(stderr, opts.Dir, opts.File) + if opts.Config == nil { + err := errors.New("GenerateOptions.Config is required") + fmt.Fprintln(stderr, err) + res.Errors = append(res.Errors, err) + return res + } + + conf, err := config.ParseConfig(opts.Config) if err != nil { + switch err { + case config.ErrMissingVersion: + fmt.Fprint(stderr, errMessageNoVersion) + case config.ErrUnknownVersion: + fmt.Fprint(stderr, errMessageUnknownVersion) + case config.ErrNoPackages: + fmt.Fprint(stderr, errMessageNoPackages) + } + fmt.Fprintf(stderr, "error parsing config: %s\n", err) res.Errors = append(res.Errors, err) return res } - base := filepath.Base(configPath) - if err := config.Validate(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) + if err := config.Validate(&conf); err != nil { + fmt.Fprintf(stderr, "error validating config: %s\n", err) res.Errors = append(res.Errors, err) return res } @@ -88,18 +98,15 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { } if !slices.Contains(opts.InsecureProcessPluginNames, plug.Name) { err := fmt.Errorf("process plugin %q is not in InsecureProcessPluginNames; refusing to run", plug.Name) - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) + fmt.Fprintf(stderr, "error validating config: %s\n", err) res.Errors = append(res.Errors, err) return res } } - g := &generator{ - dir: opts.Dir, - output: map[string]string{}, - } + g := &generator{output: map[string]string{}} - if err := processQuerySets(ctx, g, conf, opts.Dir, stderr); err != nil { + if err := processQuerySets(ctx, g, &conf, stderr); err != nil { res.Errors = append(res.Errors, err) return res } @@ -113,7 +120,7 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { } if opts.Diff { - if err := diffFiles(ctx, opts.Dir, res.Files, stderr); err != nil { + if err := diffFiles(ctx, res.Files, stderr); err != nil { res.Errors = append(res.Errors, err) } } @@ -121,9 +128,23 @@ func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { return res } +const errMessageNoVersion = `The configuration must have a version number. +Set the version to 1 or 2 at the top of the config: + +{ + "version": "1" + ... +} +` + +const errMessageUnknownVersion = `The configuration has an invalid version number. +The supported version can only be "1" or "2". +` + +const errMessageNoPackages = `No packages are configured` + type generator struct { m sync.Mutex - dir string output map[string]string } @@ -162,23 +183,25 @@ func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSett files[file.Name] = string(file.Contents) } g.m.Lock() + defer g.m.Unlock() - // out is specified by the user, not a plugin - absout := filepath.Join(g.dir, out) + absout, err := filepath.Abs(out) + if err != nil { + return err + } for n, source := range files { - filename := filepath.Join(g.dir, out, n) - // filepath.Join calls filepath.Clean which should remove all "..", but - // double check to make sure + filename, err := filepath.Abs(filepath.Join(out, n)) + if err != nil { + return err + } if strings.Contains(filename, "..") { return fmt.Errorf("invalid file output path: %s", filename) } - // The output file must be contained inside the output directory if !strings.HasPrefix(filename, absout) { return fmt.Errorf("invalid file output path: %s", filename) } g.output[filename] = source } - g.m.Unlock() return nil } diff --git a/internal/api/parse.go b/internal/api/parse.go index d2487eebea..cb4406b77e 100644 --- a/internal/api/parse.go +++ b/internal/api/parse.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "os" "path/filepath" "runtime/trace" @@ -14,15 +15,21 @@ import ( "github.com/sqlc-dev/sqlc/internal/opts" ) -func printFileErr(stderr io.Writer, dir string, fileErr *multierr.FileError) { - filename, err := filepath.Rel(dir, fileErr.Filename) +func printFileErr(stderr io.Writer, fileErr *multierr.FileError) { + wd, err := os.Getwd() if err != nil { - filename = fileErr.Filename + wd = "" + } + filename := fileErr.Filename + if wd != "" { + if rel, err := filepath.Rel(wd, fileErr.Filename); err == nil { + filename = rel + } } fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err) } -func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { +func parse(ctx context.Context, name string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { defer trace.StartRegion(ctx, "parse").End() c, err := compiler.NewCompiler(sql, combo, parserOpts) defer func() { @@ -38,7 +45,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C fmt.Fprintf(stderr, "# package %s\n", name) if parserErr, ok := err.(*multierr.Error); ok { for _, fileErr := range parserErr.Errs() { - printFileErr(stderr, dir, fileErr) + printFileErr(stderr, fileErr) } } else { fmt.Fprintf(stderr, "error parsing schema: %s\n", err) @@ -52,7 +59,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C fmt.Fprintf(stderr, "# package %s\n", name) if parserErr, ok := err.(*multierr.Error); ok { for _, fileErr := range parserErr.Errs() { - printFileErr(stderr, dir, fileErr) + printFileErr(stderr, fileErr) } } else { fmt.Fprintf(stderr, "error parsing queries: %s\n", err) diff --git a/internal/api/process.go b/internal/api/process.go index 95d2c46e1e..b4051ddb3e 100644 --- a/internal/api/process.go +++ b/internal/api/process.go @@ -29,7 +29,7 @@ type resultProcessor interface { ProcessResult(context.Context, config.CombinedSettings, outputPair, *compiler.Result) error } -func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Config, dir string, stderr io.Writer) error { +func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Config, stderr io.Writer) error { errored := false pairs := rp.Pairs(ctx, conf) @@ -48,18 +48,29 @@ func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Conf combo.Codegen = *sql.Plugin } - // TODO: This feels like a hack that will bite us later - joined := make([]string, 0, len(sql.Schema)) + absSchema := make([]string, 0, len(sql.Schema)) for _, s := range sql.Schema { - joined = append(joined, filepath.Join(dir, s)) + abs, err := filepath.Abs(s) + if err != nil { + fmt.Fprintf(errout, "resolve schema path %s: %s\n", s, err) + errored = true + return nil + } + absSchema = append(absSchema, abs) } - sql.Schema = joined + sql.Schema = absSchema - joined = make([]string, 0, len(sql.Queries)) + absQueries := make([]string, 0, len(sql.Queries)) for _, q := range sql.Queries { - joined = append(joined, filepath.Join(dir, q)) + abs, err := filepath.Abs(q) + if err != nil { + fmt.Fprintf(errout, "resolve query path %s: %s\n", q, err) + errored = true + return nil + } + absQueries = append(absQueries, abs) } - sql.Queries = joined + sql.Queries = absQueries var name, lang string parseOpts := opts.Parser{ @@ -77,9 +88,9 @@ func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Conf } packageRegion := trace.StartRegion(gctx, "package") - trace.Logf(gctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) + trace.Logf(gctx, "", "name=%s plugin=%s", name, lang) - result, failed := parse(gctx, name, dir, sql.SQL, combo, parseOpts, errout) + result, failed := parse(gctx, name, sql.SQL, combo, parseOpts, errout) if failed { packageRegion.End() errored = true diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 7519559c13..0f37288419 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -1,6 +1,7 @@ package cmd import ( + "bytes" "context" "errors" "fmt" @@ -182,19 +183,52 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) { } } -// allowedProcessPluginNames returns the set of process plugin names the CLI -// trusts to run. SQLCDEBUG=processplugins=0 disables every process plugin by -// returning nil; otherwise we trust whatever the user declared in their own -// config. -func allowedProcessPluginNames(env Env, stderr io.Writer, dir, name string) []string { - if !env.Debug.ProcessPlugins { - return nil +// loadConfig opens the sqlc config and reads it into memory. It also chdirs +// the process to the config's directory so that relative paths declared in the +// config resolve correctly when api.Generate is called. Returns the config +// bytes and the list of process plugin names declared in the config (used to +// populate api.GenerateOptions.InsecureProcessPluginNames). +func loadConfig(stderr io.Writer, dir, name string) ([]byte, []string) { + configPath, _, err := readConfig(stderr, dir, name) + if err != nil { + os.Exit(1) + } + configPath, err = filepath.Abs(configPath) + if err != nil { + fmt.Fprintf(stderr, "error resolving config path: %s\n", err) + os.Exit(1) + } + data, err := os.ReadFile(configPath) + if err != nil { + fmt.Fprintf(stderr, "error reading %s: %s\n", configPath, err) + os.Exit(1) } - names, err := processPluginNames(stderr, dir, name) + conf, err := config.ParseConfig(bytes.NewReader(data)) if err != nil { + fmt.Fprintf(stderr, "error parsing %s: %s\n", configPath, err) os.Exit(1) } - return names + if err := os.Chdir(filepath.Dir(configPath)); err != nil { + fmt.Fprintf(stderr, "error changing directory: %s\n", err) + os.Exit(1) + } + var names []string + for _, p := range conf.Plugins { + if p.Process != nil { + names = append(names, p.Name) + } + } + return data, names +} + +// allowedProcessPluginNames returns the names that should populate +// api.GenerateOptions.InsecureProcessPluginNames. SQLCDEBUG=processplugins=0 +// disables every process plugin by returning nil. +func allowedProcessPluginNames(env Env, declared []string) []string { + if !env.Debug.ProcessPlugins { + return nil + } + return declared } var genCmd = &cobra.Command{ @@ -205,12 +239,12 @@ var genCmd = &cobra.Command{ stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) env := ParseEnv(cmd) + data, declared := loadConfig(stderr, dir, name) res := api.Generate(cmd.Context(), api.GenerateOptions{ - Dir: dir, - File: name, + Config: bytes.NewReader(data), Stderr: stderr, Write: true, - InsecureProcessPluginNames: allowedProcessPluginNames(env, stderr, dir, name), + InsecureProcessPluginNames: allowedProcessPluginNames(env, declared), }) if len(res.Errors) > 0 { os.Exit(1) @@ -227,11 +261,11 @@ var checkCmd = &cobra.Command{ stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) env := ParseEnv(cmd) + data, declared := loadConfig(stderr, dir, name) res := api.Generate(cmd.Context(), api.GenerateOptions{ - Dir: dir, - File: name, + Config: bytes.NewReader(data), Stderr: stderr, - InsecureProcessPluginNames: allowedProcessPluginNames(env, stderr, dir, name), + InsecureProcessPluginNames: allowedProcessPluginNames(env, declared), }) if len(res.Errors) > 0 { os.Exit(1) @@ -248,12 +282,12 @@ var diffCmd = &cobra.Command{ stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) env := ParseEnv(cmd) + data, declared := loadConfig(stderr, dir, name) res := api.Generate(cmd.Context(), api.GenerateOptions{ - Dir: dir, - File: name, + Config: bytes.NewReader(data), Stderr: stderr, Diff: true, - InsecureProcessPluginNames: allowedProcessPluginNames(env, stderr, dir, name), + InsecureProcessPluginNames: allowedProcessPluginNames(env, declared), }) if len(res.Errors) > 0 { os.Exit(1) diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index e45d193543..1785a40718 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -106,25 +106,6 @@ func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, return configPath, &conf, nil } -// processPluginNames returns the names of every process-based plugin declared -// in the sqlc configuration at dir/filename. The CLI passes the result to -// api.GenerateOptions.InsecureProcessPluginNames so commands run by the user -// (who wrote the config) can invoke any plugin they declared, while library -// callers are still required to opt in explicitly. -func processPluginNames(stderr io.Writer, dir, filename string) ([]string, error) { - _, conf, err := readConfig(stderr, dir, filename) - if err != nil { - return nil, err - } - var names []string - for _, p := range conf.Plugins { - if p.Process != nil { - names = append(names, p.Name) - } - } - return names, nil -} - func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { defer trace.StartRegion(ctx, "parse").End() c, err := compiler.NewCompiler(sql, combo, parserOpts) diff --git a/internal/cmd/process.go b/internal/cmd/process.go index 5003d113b8..c5e3286299 100644 --- a/internal/cmd/process.go +++ b/internal/cmd/process.go @@ -74,15 +74,21 @@ func processQuerySets(ctx context.Context, rp ResultProcessor, conf *config.Conf } // TODO: This feels like a hack that will bite us later + joinDir := func(p string) string { + if filepath.IsAbs(p) { + return p + } + return filepath.Join(dir, p) + } joined := make([]string, 0, len(sql.Schema)) for _, s := range sql.Schema { - joined = append(joined, filepath.Join(dir, s)) + joined = append(joined, joinDir(s)) } sql.Schema = joined joined = make([]string, 0, len(sql.Queries)) for _, q := range sql.Queries { - joined = append(joined, filepath.Join(dir, q)) + joined = append(joined, joinDir(q)) } sql.Queries = joined diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index 4dbd3c3b7b..3d70d9a301 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -467,15 +467,21 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { combo := config.Combine(*c.Conf, s) // TODO: This feels like a hack that will bite us later + joinDir := func(p string) string { + if filepath.IsAbs(p) { + return p + } + return filepath.Join(c.Dir, p) + } joined := make([]string, 0, len(s.Schema)) - for _, s := range s.Schema { - joined = append(joined, filepath.Join(c.Dir, s)) + for _, p := range s.Schema { + joined = append(joined, joinDir(p)) } s.Schema = joined joined = make([]string, 0, len(s.Queries)) for _, q := range s.Queries { - joined = append(joined, filepath.Join(c.Dir, q)) + joined = append(joined, joinDir(q)) } s.Queries = joined diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 68cc037225..9bcb36d482 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "os" osexec "os/exec" "path/filepath" @@ -62,7 +63,7 @@ func TestExamples(t *testing.T) { path := filepath.Join(examples, tc) var stderr bytes.Buffer res := api.Generate(ctx, api.GenerateOptions{ - Dir: path, + Config: openConfigReader(t, path), Stderr: &stderr, }) if len(res.Errors) > 0 { @@ -90,10 +91,11 @@ func BenchmarkExamples(b *testing.B) { tc := replay.Name() b.Run(tc, func(b *testing.B) { path := filepath.Join(examples, tc) + cfg := openConfigBytes(b, path) for i := 0; i < b.N; i++ { var stderr bytes.Buffer api.Generate(ctx, api.GenerateOptions{ - Dir: path, + Config: bytes.NewReader(cfg), Stderr: &stderr, }) } @@ -101,51 +103,94 @@ func BenchmarkExamples(b *testing.B) { } } -// textContext describes a TestReplay scenario. Mutate returns the config -// filename (relative to the test directory) that should be passed to the -// command under test. The "base" context returns "" to use the project's -// existing sqlc config; the "managed-db" context writes a mutated copy of the -// config to a temporary file inside the test directory and returns its name. -type textContext struct { - Mutate func(*testing.T, string) string - Enabled func() bool +// openConfigReader reads the sqlc config in dir, rewrites every relative +// schema/queries/output path to an absolute one (so api.Generate doesn't have +// to know the config's directory), and returns the result as an io.Reader. +func openConfigReader(t testing.TB, dir string) io.Reader { + return bytes.NewReader(openConfigBytes(t, dir)) +} + +func openConfigBytes(t testing.TB, dir string) []byte { + t.Helper() + data, _ := mutatedConfigBytes(t, dir, nil) + return data } -// writeMutatedConfig parses the sqlc config in dir, applies mutate to the -// in-memory Config (which is always v2-shaped, even when the file on disk is -// v1), forces version "2", and writes the result to a temp file alongside the -// original. The temp file is removed when the test ends. -func writeMutatedConfig(t *testing.T, dir string, mutate func(*config.Config)) string { +// mutatedConfigBytes parses the sqlc config in dir, applies mutate (when +// non-nil) to the in-memory Config, makes every path absolute relative to dir, +// and re-encodes as YAML. Parsing v1 configs converts them to a v2-shaped +// Config; we force version "2" so the result can be parsed back by api.Generate. +// +// When mutate is non-nil, the encoded bytes are also written to a temp file +// alongside the original (cleaned up at test end) and the filename is returned +// so callers like cmd.Vet that still take a config-file path can use it. +func mutatedConfigBytes(t testing.TB, dir string, mutate func(*config.Config)) ([]byte, string) { t.Helper() original, conf, err := readSqlcConfig(dir) if err != nil { t.Fatalf("read sqlc config from %s: %s", dir, err) } - // Parsing v1 configs converts them to a v2-shaped Config. Force version "2" - // so the mutated config can be re-parsed as v2 from disk. conf.Version = "2" - mutate(conf) + absolutizePaths(conf, dir) + if mutate != nil { + mutate(conf) + } + + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + if err := enc.Encode(conf); err != nil { + t.Fatalf("encode config: %s", err) + } + if err := enc.Close(); err != nil { + t.Fatalf("close yaml encoder: %s", err) + } + data := buf.Bytes() + + if mutate == nil { + return data, "" + } f, err := os.CreateTemp(dir, "sqlc.test-*"+filepath.Ext(original)) if err != nil { t.Fatalf("create temp config in %s: %s", dir, err) } t.Cleanup(func() { os.Remove(f.Name()) }) - - enc := yaml.NewEncoder(f) - if err := enc.Encode(conf); err != nil { + if _, err := f.Write(data); err != nil { f.Close() t.Fatalf("write temp config %s: %s", f.Name(), err) } - if err := enc.Close(); err != nil { - f.Close() - t.Fatalf("close yaml encoder for %s: %s", f.Name(), err) - } if err := f.Close(); err != nil { t.Fatalf("close temp config %s: %s", f.Name(), err) } - return filepath.Base(f.Name()) + return data, filepath.Base(f.Name()) +} + +func absolutizePaths(conf *config.Config, dir string) { + abs := func(p string) string { + if p == "" || filepath.IsAbs(p) { + return p + } + return filepath.Join(dir, p) + } + for i := range conf.SQL { + s := &conf.SQL[i] + for j, p := range s.Schema { + s.Schema[j] = abs(p) + } + for j, p := range s.Queries { + s.Queries[j] = abs(p) + } + if s.Gen.Go != nil { + s.Gen.Go.Out = abs(s.Gen.Go.Out) + } + if s.Gen.JSON != nil { + s.Gen.JSON.Out = abs(s.Gen.JSON.Out) + } + for j := range s.Codegen { + s.Codegen[j].Out = abs(s.Codegen[j].Out) + } + } } func readSqlcConfig(dir string) (string, *config.Config, error) { @@ -168,6 +213,18 @@ func readSqlcConfig(dir string) (string, *config.Config, error) { return "", nil, fmt.Errorf("no sqlc config found in %s", dir) } +// configRef is the result of preparing a config for a single TestReplay case. +// reader is for api.Generate; file is for cmd.Vet which still takes a path. +type configRef struct { + reader io.Reader + file string +} + +type textContext struct { + Config func(*testing.T, string) configRef + Enabled func() bool +} + func TestReplay(t *testing.T) { // Ensure that this environment variable is always set to true when running // end-to-end tests @@ -236,12 +293,15 @@ func TestReplay(t *testing.T) { contexts := map[string]textContext{ "base": { - Mutate: func(t *testing.T, path string) string { return "" }, + Config: func(t *testing.T, dir string) configRef { + data, _ := mutatedConfigBytes(t, dir, nil) + return configRef{reader: bytes.NewReader(data)} + }, Enabled: func() bool { return true }, }, "managed-db": { - Mutate: func(t *testing.T, path string) string { - return writeMutatedConfig(t, path, func(c *config.Config) { + Config: func(t *testing.T, dir string) configRef { + data, file := mutatedConfigBytes(t, dir, func(c *config.Config) { // Add all servers - tests will fail if database isn't available c.Servers = []config.Server{ {Name: "postgres", Engine: config.EnginePostgreSQL, URI: postgresURI}, @@ -254,6 +314,7 @@ func TestReplay(t *testing.T) { } } }) + return configRef{reader: bytes.NewReader(data), file: file} }, Enabled: func() bool { // Enabled if at least one database URI is available @@ -303,7 +364,7 @@ func TestReplay(t *testing.T) { } } - configFile := testctx.Mutate(t, path) + cfg := testctx.Config(t, path) cmdOpts := cmd.Options{ Env: cmd.Env{ Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]), @@ -315,8 +376,7 @@ func TestReplay(t *testing.T) { switch args.Command { case "diff": res := api.Generate(ctx, api.GenerateOptions{ - Dir: path, - File: configFile, + Config: cfg.reader, Stderr: &stderr, Diff: true, }) @@ -325,8 +385,7 @@ func TestReplay(t *testing.T) { } case "generate": res := api.Generate(ctx, api.GenerateOptions{ - Dir: path, - File: configFile, + Config: cfg.reader, Stderr: &stderr, }) output = res.Files @@ -337,7 +396,7 @@ func TestReplay(t *testing.T) { cmpDirectory(t, path, output) } case "vet": - err = cmd.Vet(ctx, path, configFile, &cmdOpts) + err = cmd.Vet(ctx, path, cfg.file, &cmdOpts) default: t.Fatalf("unknown command") } @@ -381,7 +440,7 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { if filepath.Base(path) == "exec.json" { return nil } - // Mutated configs written by writeMutatedConfig. + // Mutated configs written by mutatedConfigBytes. if strings.HasPrefix(filepath.Base(path), "sqlc.test-") { return nil } @@ -447,10 +506,11 @@ func BenchmarkReplay(b *testing.B) { tc := replay b.Run(tc, func(b *testing.B) { path, _ := filepath.Abs(tc) + cfg := openConfigBytes(b, path) for i := 0; i < b.N; i++ { var stderr bytes.Buffer api.Generate(ctx, api.GenerateOptions{ - Dir: path, + Config: bytes.NewReader(cfg), Stderr: &stderr, }) }