diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 0000000000..c1011de572 --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,11 @@ +// Package api is intended to be the future public API for sqlc. +// +// The shape of this package is inspired by esbuild's Build API +// (https://pkg.go.dev/github.com/evanw/esbuild/pkg/api#hdr-Build_API): a small +// surface area of options structs and result structs that lets callers drive +// sqlc programmatically without going through the CLI. +// +// Today the package lives under internal/ while the API stabilises. Once the +// surface settles it is expected to graduate to pkg/api so it can be imported +// by external Go programs. +package api diff --git a/internal/api/codegen.go b/internal/api/codegen.go new file mode 100644 index 0000000000..d733c3b2d4 --- /dev/null +++ b/internal/api/codegen.go @@ -0,0 +1,108 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "runtime/trace" + + "google.golang.org/grpc" + + "github.com/sqlc-dev/sqlc/internal/codegen/golang" + genjson "github.com/sqlc-dev/sqlc/internal/codegen/json" + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/config/convert" + "github.com/sqlc-dev/sqlc/internal/ext" + "github.com/sqlc-dev/sqlc/internal/ext/process" + "github.com/sqlc-dev/sqlc/internal/ext/wasm" + "github.com/sqlc-dev/sqlc/internal/plugin" +) + +func findPlugin(conf config.Config, name string) (*config.Plugin, error) { + for _, plug := range conf.Plugins { + if plug.Name == name { + return &plug, nil + } + } + return nil, fmt.Errorf("plugin not found") +} + +func codegen(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) { + defer trace.StartRegion(ctx, "codegen").End() + req := codeGenRequest(result, combo) + var handler grpc.ClientConnInterface + var out string + switch { + case sql.Plugin != nil: + out = sql.Plugin.Out + plug, err := findPlugin(combo.Global, sql.Plugin.Plugin) + if err != nil { + return "", nil, fmt.Errorf("plugin not found: %s", err) + } + + switch { + case plug.Process != nil: + handler = &process.Runner{ + Cmd: plug.Process.Cmd, + Env: plug.Env, + Format: plug.Process.Format, + } + case plug.WASM != nil: + handler = &wasm.Runner{ + URL: plug.WASM.URL, + SHA256: plug.WASM.SHA256, + Env: plug.Env, + } + default: + return "", nil, fmt.Errorf("unsupported plugin type") + } + + opts, err := convert.YAMLtoJSON(sql.Plugin.Options) + if err != nil { + return "", nil, fmt.Errorf("invalid plugin options: %w", err) + } + req.PluginOptions = opts + + global, found := combo.Global.Options[plug.Name] + if found { + opts, err := convert.YAMLtoJSON(global) + if err != nil { + return "", nil, fmt.Errorf("invalid global options: %w", err) + } + req.GlobalOptions = opts + } + + case sql.Gen.Go != nil: + out = combo.Go.Out + handler = ext.HandleFunc(golang.Generate) + opts, err := json.Marshal(sql.Gen.Go) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.PluginOptions = opts + + if combo.Global.Overrides.Go != nil { + opts, err := json.Marshal(combo.Global.Overrides.Go) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.GlobalOptions = opts + } + + case sql.Gen.JSON != nil: + out = combo.JSON.Out + handler = ext.HandleFunc(genjson.Generate) + opts, err := json.Marshal(sql.Gen.JSON) + if err != nil { + return "", nil, fmt.Errorf("opts marshal failed: %w", err) + } + req.PluginOptions = opts + + default: + return "", nil, fmt.Errorf("missing language backend") + } + client := plugin.NewCodegenServiceClient(handler) + resp, err := client.Generate(ctx, req) + return out, resp, err +} diff --git a/internal/api/diff.go b/internal/api/diff.go new file mode 100644 index 0000000000..5e70797fc9 --- /dev/null +++ b/internal/api/diff.go @@ -0,0 +1,107 @@ +package api + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "runtime/trace" + "sort" + + "github.com/cubicdaiya/gonp" +) + +func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer) error { + defer trace.StartRegion(ctx, "writefiles").End() + for filename, source := range files { + if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil { + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + return err + } + if err := os.WriteFile(filename, []byte(source), 0644); err != nil { + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + return err + } + } + return nil +} + +func diffFiles(ctx context.Context, files map[string]string, stderr io.Writer) error { + defer trace.StartRegion(ctx, "checkfiles").End() + var errored bool + + wd, _ := os.Getwd() + + keys := make([]string, 0, len(files)) + for k := range files { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, filename := range keys { + source := files[filename] + if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) { + errored = true + continue + } + existing, err := os.ReadFile(filename) + if err != nil { + errored = true + fmt.Fprintf(stderr, "%s: %s\n", filename, err) + continue + } + d := gonp.New(getLines(existing), getLines([]byte(source))) + d.Compose() + uniHunks := filterHunks(d.UnifiedHunks()) + + if len(uniHunks) > 0 { + errored = true + label := filename + if wd != "" { + if rel, err := filepath.Rel(wd, filename); err == nil { + label = "/" + rel + } + } + fmt.Fprintf(stderr, "--- a%s\n", label) + fmt.Fprintf(stderr, "+++ b%s\n", label) + d.FprintUniHunks(stderr, uniHunks) + } + } + if errored { + return errors.New("diff found") + } + return nil +} + +func getLines(f []byte) []string { + fp := bytes.NewReader(f) + scanner := bufio.NewScanner(fp) + lines := make([]string, 0) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + return lines +} + +func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] { + var out []gonp.UniHunk[T] + for i, uniHunk := range uniHunks { + var changed bool + for _, e := range uniHunk.GetChanges() { + switch e.GetType() { + case gonp.SesDelete: + changed = true + case gonp.SesAdd: + changed = true + } + } + if changed { + out = append(out, uniHunks[i]) + } + } + return out +} diff --git a/internal/api/generate.go b/internal/api/generate.go new file mode 100644 index 0000000000..8c7bd6efce --- /dev/null +++ b/internal/api/generate.go @@ -0,0 +1,207 @@ +package api + +import ( + "context" + "errors" + "fmt" + "io" + "path/filepath" + "slices" + "strings" + "sync" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" +) + +// GenerateOptions controls a single Generate invocation. Paths declared in the +// configuration are resolved relative to the current working directory, so +// callers wanting a different base directory should either pass absolute +// paths in the config or os.Chdir before calling. +type GenerateOptions struct { + // Config is the sqlc configuration as a YAML or JSON document. Required. + Config io.Reader + + // Stderr receives diagnostic output. If nil, output is discarded. + Stderr io.Writer + + // Write writes the generated files to disk after a successful generate. + // Failures are reported via GenerateResult.Errors. + Write bool + + // Diff compares each generated file against any existing file on disk and + // writes a unified diff for differences to Stderr. If any differences are + // found, an error is appended to GenerateResult.Errors. + Diff bool + + // InsecureProcessPluginNames is the allowlist of process-based plugin + // names that Generate is permitted to invoke. Any process plugin declared + // in the configuration whose name is not in this list causes Generate to + // fail before parsing or codegen runs. Process plugins execute arbitrary + // local commands; the "Insecure" prefix mirrors + // crypto/tls.Config.InsecureSkipVerify as a reminder that callers must + // consciously trust each plugin name they pass here. + InsecureProcessPluginNames []string +} + +// GenerateResult is the outcome of a Generate call. +type GenerateResult struct { + // Files maps absolute output paths to generated file contents. + Files map[string]string + + // Errors collects any errors encountered. A non-empty Errors slice means + // generation did not fully succeed. + Errors []error +} + +// Generate parses the sqlc configuration referenced by opts and runs every +// configured codegen target. +func Generate(ctx context.Context, opts GenerateOptions) GenerateResult { + stderr := opts.Stderr + if stderr == nil { + stderr = io.Discard + } + + res := GenerateResult{Files: map[string]string{}} + + if opts.Config == nil { + err := errors.New("GenerateOptions.Config is required") + fmt.Fprintln(stderr, err) + res.Errors = append(res.Errors, err) + return res + } + + conf, err := config.ParseConfig(opts.Config) + if err != nil { + switch err { + case config.ErrMissingVersion: + fmt.Fprint(stderr, errMessageNoVersion) + case config.ErrUnknownVersion: + fmt.Fprint(stderr, errMessageUnknownVersion) + case config.ErrNoPackages: + fmt.Fprint(stderr, errMessageNoPackages) + } + fmt.Fprintf(stderr, "error parsing config: %s\n", err) + res.Errors = append(res.Errors, err) + return res + } + + if err := config.Validate(&conf); err != nil { + fmt.Fprintf(stderr, "error validating config: %s\n", err) + res.Errors = append(res.Errors, err) + return res + } + + for _, plug := range conf.Plugins { + if plug.Process == nil { + continue + } + if !slices.Contains(opts.InsecureProcessPluginNames, plug.Name) { + err := fmt.Errorf("process plugin %q is not in InsecureProcessPluginNames; refusing to run", plug.Name) + fmt.Fprintf(stderr, "error validating config: %s\n", err) + res.Errors = append(res.Errors, err) + return res + } + } + + g := &generator{output: map[string]string{}} + + if err := processQuerySets(ctx, g, &conf, stderr); err != nil { + res.Errors = append(res.Errors, err) + return res + } + + res.Files = g.output + + if opts.Write { + if err := writeFiles(ctx, res.Files, stderr); err != nil { + res.Errors = append(res.Errors, err) + } + } + + if opts.Diff { + if err := diffFiles(ctx, res.Files, stderr); err != nil { + res.Errors = append(res.Errors, err) + } + } + + return res +} + +const errMessageNoVersion = `The configuration must have a version number. +Set the version to 1 or 2 at the top of the config: + +{ + "version": "1" + ... +} +` + +const errMessageUnknownVersion = `The configuration has an invalid version number. +The supported version can only be "1" or "2". +` + +const errMessageNoPackages = `No packages are configured` + +type generator struct { + m sync.Mutex + output map[string]string +} + +func (g *generator) Pairs(ctx context.Context, conf *config.Config) []outputPair { + var pairs []outputPair + for _, sql := range conf.SQL { + if sql.Gen.Go != nil { + pairs = append(pairs, outputPair{ + SQL: sql, + Gen: config.SQLGen{Go: sql.Gen.Go}, + }) + } + if sql.Gen.JSON != nil { + pairs = append(pairs, outputPair{ + SQL: sql, + Gen: config.SQLGen{JSON: sql.Gen.JSON}, + }) + } + for i := range sql.Codegen { + pairs = append(pairs, outputPair{ + SQL: sql, + Plugin: &sql.Codegen[i], + }) + } + } + return pairs +} + +func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) error { + out, resp, err := codegen(ctx, combo, sql, result) + if err != nil { + return err + } + files := map[string]string{} + for _, file := range resp.Files { + files[file.Name] = string(file.Contents) + } + g.m.Lock() + defer g.m.Unlock() + + absout, err := filepath.Abs(out) + if err != nil { + return err + } + + for n, source := range files { + filename, err := filepath.Abs(filepath.Join(out, n)) + if err != nil { + return err + } + if strings.Contains(filename, "..") { + return fmt.Errorf("invalid file output path: %s", filename) + } + if !strings.HasPrefix(filename, absout) { + return fmt.Errorf("invalid file output path: %s", filename) + } + g.output[filename] = source + } + return nil +} diff --git a/internal/api/parse.go b/internal/api/parse.go new file mode 100644 index 0000000000..cb4406b77e --- /dev/null +++ b/internal/api/parse.go @@ -0,0 +1,70 @@ +package api + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "runtime/trace" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/multierr" + "github.com/sqlc-dev/sqlc/internal/opts" +) + +func printFileErr(stderr io.Writer, fileErr *multierr.FileError) { + wd, err := os.Getwd() + if err != nil { + wd = "" + } + filename := fileErr.Filename + if wd != "" { + if rel, err := filepath.Rel(wd, fileErr.Filename); err == nil { + filename = rel + } + } + fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err) +} + +func parse(ctx context.Context, name string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { + defer trace.StartRegion(ctx, "parse").End() + c, err := compiler.NewCompiler(sql, combo, parserOpts) + defer func() { + if c != nil { + c.Close(ctx) + } + }() + if err != nil { + fmt.Fprintf(stderr, "error creating compiler: %s\n", err) + return nil, true + } + if err := c.ParseCatalog(sql.Schema); err != nil { + fmt.Fprintf(stderr, "# package %s\n", name) + if parserErr, ok := err.(*multierr.Error); ok { + for _, fileErr := range parserErr.Errs() { + printFileErr(stderr, fileErr) + } + } else { + fmt.Fprintf(stderr, "error parsing schema: %s\n", err) + } + return nil, true + } + if parserOpts.Debug.DumpCatalog { + debug.Dump(c.Catalog()) + } + if err := c.ParseQueries(sql.Queries, parserOpts); err != nil { + fmt.Fprintf(stderr, "# package %s\n", name) + if parserErr, ok := err.(*multierr.Error); ok { + for _, fileErr := range parserErr.Errs() { + printFileErr(stderr, fileErr) + } + } else { + fmt.Fprintf(stderr, "error parsing queries: %s\n", err) + } + return nil, true + } + return c.Result(), false +} diff --git a/internal/api/process.go b/internal/api/process.go new file mode 100644 index 0000000000..b4051ddb3e --- /dev/null +++ b/internal/api/process.go @@ -0,0 +1,120 @@ +package api + +import ( + "bytes" + "context" + "fmt" + "io" + "path/filepath" + "runtime" + "runtime/trace" + + "golang.org/x/sync/errgroup" + + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/debug" + "github.com/sqlc-dev/sqlc/internal/opts" +) + +type outputPair struct { + Gen config.SQLGen + Plugin *config.Codegen + + config.SQL +} + +type resultProcessor interface { + Pairs(context.Context, *config.Config) []outputPair + ProcessResult(context.Context, config.CombinedSettings, outputPair, *compiler.Result) error +} + +func processQuerySets(ctx context.Context, rp resultProcessor, conf *config.Config, stderr io.Writer) error { + errored := false + + pairs := rp.Pairs(ctx, conf) + grp, gctx := errgroup.WithContext(ctx) + grp.SetLimit(runtime.GOMAXPROCS(0)) + + stderrs := make([]bytes.Buffer, len(pairs)) + + for i, pair := range pairs { + sql := pair + errout := &stderrs[i] + + grp.Go(func() error { + combo := config.Combine(*conf, sql.SQL) + if sql.Plugin != nil { + combo.Codegen = *sql.Plugin + } + + absSchema := make([]string, 0, len(sql.Schema)) + for _, s := range sql.Schema { + abs, err := filepath.Abs(s) + if err != nil { + fmt.Fprintf(errout, "resolve schema path %s: %s\n", s, err) + errored = true + return nil + } + absSchema = append(absSchema, abs) + } + sql.Schema = absSchema + + absQueries := make([]string, 0, len(sql.Queries)) + for _, q := range sql.Queries { + abs, err := filepath.Abs(q) + if err != nil { + fmt.Fprintf(errout, "resolve query path %s: %s\n", q, err) + errored = true + return nil + } + absQueries = append(absQueries, abs) + } + sql.Queries = absQueries + + var name, lang string + parseOpts := opts.Parser{ + Debug: debug.Debug, + } + + switch { + case sql.Gen.Go != nil: + name = combo.Go.Package + lang = "golang" + + case sql.Plugin != nil: + lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin) + name = sql.Plugin.Plugin + } + + packageRegion := trace.StartRegion(gctx, "package") + trace.Logf(gctx, "", "name=%s plugin=%s", name, lang) + + result, failed := parse(gctx, name, sql.SQL, combo, parseOpts, errout) + if failed { + packageRegion.End() + errored = true + return nil + } + if err := rp.ProcessResult(gctx, combo, sql, result); err != nil { + fmt.Fprintf(errout, "# package %s\n", name) + fmt.Fprintf(errout, "error generating code: %s\n", err) + errored = true + } + packageRegion.End() + return nil + }) + } + if err := grp.Wait(); err != nil { + return err + } + if errored { + for i := range stderrs { + if _, err := io.Copy(stderr, &stderrs[i]); err != nil { + return err + } + } + return fmt.Errorf("errored") + } + return nil +} diff --git a/internal/api/shim.go b/internal/api/shim.go new file mode 100644 index 0000000000..9638d1f126 --- /dev/null +++ b/internal/api/shim.go @@ -0,0 +1,233 @@ +package api + +import ( + "github.com/sqlc-dev/sqlc/internal/compiler" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/config/convert" + "github.com/sqlc-dev/sqlc/internal/info" + "github.com/sqlc-dev/sqlc/internal/plugin" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" +) + +func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings { + return &plugin.Settings{ + Version: cs.Global.Version, + Engine: string(cs.Package.Engine), + Schema: []string(cs.Package.Schema), + Queries: []string(cs.Package.Queries), + Codegen: pluginCodegen(cs, cs.Codegen), + } +} + +func pluginCodegen(cs config.CombinedSettings, s config.Codegen) *plugin.Codegen { + opts, err := convert.YAMLtoJSON(s.Options) + if err != nil { + panic(err) + } + cg := &plugin.Codegen{ + Out: s.Out, + Plugin: s.Plugin, + Options: opts, + } + for _, p := range cs.Global.Plugins { + if p.Name == s.Plugin { + cg.Env = p.Env + cg.Process = pluginProcess(p) + cg.Wasm = pluginWASM(p) + return cg + } + } + return cg +} + +func pluginProcess(p config.Plugin) *plugin.Codegen_Process { + if p.Process != nil { + return &plugin.Codegen_Process{ + Cmd: p.Process.Cmd, + } + } + return nil +} + +func pluginWASM(p config.Plugin) *plugin.Codegen_WASM { + if p.WASM != nil { + return &plugin.Codegen_WASM{ + Url: p.WASM.URL, + Sha256: p.WASM.SHA256, + } + } + return nil +} + +func pluginCatalog(c *catalog.Catalog) *plugin.Catalog { + var schemas []*plugin.Schema + for _, s := range c.Schemas { + var enums []*plugin.Enum + var cts []*plugin.CompositeType + for _, typ := range s.Types { + switch typ := typ.(type) { + case *catalog.Enum: + enums = append(enums, &plugin.Enum{ + Name: typ.Name, + Comment: typ.Comment, + Vals: typ.Vals, + }) + case *catalog.CompositeType: + cts = append(cts, &plugin.CompositeType{ + Name: typ.Name, + Comment: typ.Comment, + }) + } + } + var tables []*plugin.Table + for _, t := range s.Tables { + var columns []*plugin.Column + for _, c := range t.Columns { + l := -1 + if c.Length != nil { + l = *c.Length + } + columns = append(columns, &plugin.Column{ + Name: c.Name, + Type: &plugin.Identifier{ + Catalog: c.Type.Catalog, + Schema: c.Type.Schema, + Name: c.Type.Name, + }, + Comment: c.Comment, + NotNull: c.IsNotNull, + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + Table: &plugin.Identifier{ + Catalog: t.Rel.Catalog, + Schema: t.Rel.Schema, + Name: t.Rel.Name, + }, + }) + } + tables = append(tables, &plugin.Table{ + Rel: &plugin.Identifier{ + Catalog: t.Rel.Catalog, + Schema: t.Rel.Schema, + Name: t.Rel.Name, + }, + Columns: columns, + Comment: t.Comment, + }) + } + schemas = append(schemas, &plugin.Schema{ + Comment: s.Comment, + Name: s.Name, + Tables: tables, + Enums: enums, + CompositeTypes: cts, + }) + } + return &plugin.Catalog{ + Name: c.Name, + DefaultSchema: c.DefaultSchema, + Comment: c.Comment, + Schemas: schemas, + } +} + +func pluginQueries(r *compiler.Result) []*plugin.Query { + var out []*plugin.Query + for _, q := range r.Queries { + var params []*plugin.Parameter + var columns []*plugin.Column + for _, c := range q.Columns { + columns = append(columns, pluginQueryColumn(c)) + } + for _, p := range q.Params { + params = append(params, pluginQueryParam(p)) + } + var iit *plugin.Identifier + if q.InsertIntoTable != nil { + iit = &plugin.Identifier{ + Catalog: q.InsertIntoTable.Catalog, + Schema: q.InsertIntoTable.Schema, + Name: q.InsertIntoTable.Name, + } + } + out = append(out, &plugin.Query{ + Name: q.Metadata.Name, + Cmd: q.Metadata.Cmd, + Text: q.SQL, + Comments: q.Metadata.Comments, + Columns: columns, + Params: params, + Filename: q.Metadata.Filename, + InsertIntoTable: iit, + }) + } + return out +} + +func pluginQueryColumn(c *compiler.Column) *plugin.Column { + l := -1 + if c.Length != nil { + l = *c.Length + } + out := &plugin.Column{ + Name: c.Name, + OriginalName: c.OriginalName, + Comment: c.Comment, + NotNull: c.NotNull, + Unsigned: c.Unsigned, + IsArray: c.IsArray, + ArrayDims: int32(c.ArrayDims), + Length: int32(l), + IsNamedParam: c.IsNamedParam, + IsFuncCall: c.IsFuncCall, + IsSqlcSlice: c.IsSqlcSlice, + } + + if c.Type != nil { + out.Type = &plugin.Identifier{ + Catalog: c.Type.Catalog, + Schema: c.Type.Schema, + Name: c.Type.Name, + } + } else { + out.Type = &plugin.Identifier{ + Name: c.DataType, + } + } + + if c.Table != nil { + out.Table = &plugin.Identifier{ + Catalog: c.Table.Catalog, + Schema: c.Table.Schema, + Name: c.Table.Name, + } + } + + if c.EmbedTable != nil { + out.EmbedTable = &plugin.Identifier{ + Catalog: c.EmbedTable.Catalog, + Schema: c.EmbedTable.Schema, + Name: c.EmbedTable.Name, + } + } + + return out +} + +func pluginQueryParam(p compiler.Parameter) *plugin.Parameter { + return &plugin.Parameter{ + Number: int32(p.Number), + Column: pluginQueryColumn(p.Column), + } +} + +func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.GenerateRequest { + return &plugin.GenerateRequest{ + Settings: pluginSettings(r, settings), + Catalog: pluginCatalog(r.Catalog), + Queries: pluginQueries(r), + SqlcVersion: info.Version, + } +} diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index f9c09dfe06..0f37288419 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -1,7 +1,6 @@ package cmd import ( - "bufio" "bytes" "context" "errors" @@ -12,11 +11,11 @@ import ( "path/filepath" "runtime/trace" - "github.com/cubicdaiya/gonp" "github.com/spf13/cobra" "github.com/spf13/pflag" "gopkg.in/yaml.v3" + "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/debug" "github.com/sqlc-dev/sqlc/internal/info" @@ -184,6 +183,54 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) { } } +// loadConfig opens the sqlc config and reads it into memory. It also chdirs +// the process to the config's directory so that relative paths declared in the +// config resolve correctly when api.Generate is called. Returns the config +// bytes and the list of process plugin names declared in the config (used to +// populate api.GenerateOptions.InsecureProcessPluginNames). +func loadConfig(stderr io.Writer, dir, name string) ([]byte, []string) { + configPath, _, err := readConfig(stderr, dir, name) + if err != nil { + os.Exit(1) + } + configPath, err = filepath.Abs(configPath) + if err != nil { + fmt.Fprintf(stderr, "error resolving config path: %s\n", err) + os.Exit(1) + } + data, err := os.ReadFile(configPath) + if err != nil { + fmt.Fprintf(stderr, "error reading %s: %s\n", configPath, err) + os.Exit(1) + } + conf, err := config.ParseConfig(bytes.NewReader(data)) + if err != nil { + fmt.Fprintf(stderr, "error parsing %s: %s\n", configPath, err) + os.Exit(1) + } + if err := os.Chdir(filepath.Dir(configPath)); err != nil { + fmt.Fprintf(stderr, "error changing directory: %s\n", err) + os.Exit(1) + } + var names []string + for _, p := range conf.Plugins { + if p.Process != nil { + names = append(names, p.Name) + } + } + return data, names +} + +// allowedProcessPluginNames returns the names that should populate +// api.GenerateOptions.InsecureProcessPluginNames. SQLCDEBUG=processplugins=0 +// disables every process plugin by returning nil. +func allowedProcessPluginNames(env Env, declared []string) []string { + if !env.Debug.ProcessPlugins { + return nil + } + return declared +} + var genCmd = &cobra.Command{ Use: "generate", Short: "Generate source code from SQL", @@ -191,21 +238,17 @@ var genCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "generate").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - output, err := Generate(cmd.Context(), dir, name, &Options{ - Env: ParseEnv(cmd), - Stderr: stderr, + env := ParseEnv(cmd) + data, declared := loadConfig(stderr, dir, name) + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Config: bytes.NewReader(data), + Stderr: stderr, + Write: true, + InsecureProcessPluginNames: allowedProcessPluginNames(env, declared), }) - if err != nil { + if len(res.Errors) > 0 { os.Exit(1) } - defer trace.StartRegion(cmd.Context(), "writefiles").End() - for filename, source := range output { - os.MkdirAll(filepath.Dir(filename), 0755) - if err := os.WriteFile(filename, []byte(source), 0644); err != nil { - fmt.Fprintf(stderr, "%s: %s\n", filename, err) - return err - } - } return nil }, } @@ -217,46 +260,20 @@ var checkCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "compile").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - _, err := Generate(cmd.Context(), dir, name, &Options{ - Env: ParseEnv(cmd), - Stderr: stderr, + env := ParseEnv(cmd) + data, declared := loadConfig(stderr, dir, name) + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Config: bytes.NewReader(data), + Stderr: stderr, + InsecureProcessPluginNames: allowedProcessPluginNames(env, declared), }) - if err != nil { + if len(res.Errors) > 0 { os.Exit(1) } return nil }, } -func getLines(f []byte) []string { - fp := bytes.NewReader(f) - scanner := bufio.NewScanner(fp) - lines := make([]string, 0) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - return lines -} - -func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] { - var out []gonp.UniHunk[T] - for i, uniHunk := range uniHunks { - var changed bool - for _, e := range uniHunk.GetChanges() { - switch e.GetType() { - case gonp.SesDelete: - changed = true - case gonp.SesAdd: - changed = true - } - } - if changed { - out = append(out, uniHunks[i]) - } - } - return out -} - var diffCmd = &cobra.Command{ Use: "diff", Short: "Compare the generated files to the existing files", @@ -264,11 +281,15 @@ var diffCmd = &cobra.Command{ defer trace.StartRegion(cmd.Context(), "diff").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - opts := &Options{ - Env: ParseEnv(cmd), - Stderr: stderr, - } - if err := Diff(cmd.Context(), dir, name, opts); err != nil { + env := ParseEnv(cmd) + data, declared := loadConfig(stderr, dir, name) + res := api.Generate(cmd.Context(), api.GenerateOptions{ + Config: bytes.NewReader(data), + Stderr: stderr, + Diff: true, + InsecureProcessPluginNames: allowedProcessPluginNames(env, declared), + }) + if len(res.Errors) > 0 { os.Exit(1) } return nil diff --git a/internal/cmd/diff.go b/internal/cmd/diff.go deleted file mode 100644 index 8998971a37..0000000000 --- a/internal/cmd/diff.go +++ /dev/null @@ -1,59 +0,0 @@ -package cmd - -import ( - "context" - "errors" - "fmt" - "os" - "runtime/trace" - "sort" - "strings" - - "github.com/cubicdaiya/gonp" -) - -func Diff(ctx context.Context, dir, name string, opts *Options) error { - stderr := opts.Stderr - output, err := Generate(ctx, dir, name, opts) - if err != nil { - return err - } - defer trace.StartRegion(ctx, "checkfiles").End() - var errored bool - - keys := make([]string, 0, len(output)) - for k, _ := range output { - kk := k - keys = append(keys, kk) - } - sort.Strings(keys) - - for _, filename := range keys { - source := output[filename] - if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) { - errored = true - // stdout message - continue - } - existing, err := os.ReadFile(filename) - if err != nil { - errored = true - fmt.Fprintf(stderr, "%s: %s\n", filename, err) - continue - } - diff := gonp.New(getLines(existing), getLines([]byte(source))) - diff.Compose() - uniHunks := filterHunks(diff.UnifiedHunks()) - - if len(uniHunks) > 0 { - errored = true - fmt.Fprintf(stderr, "--- a%s\n", strings.TrimPrefix(filename, dir)) - fmt.Fprintf(stderr, "+++ b%s\n", strings.TrimPrefix(filename, dir)) - diff.FprintUniHunks(stderr, uniHunks) - } - } - if errored { - return errors.New("diff found") - } - return nil -} diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index ca3ee680b5..1785a40718 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -2,30 +2,18 @@ package cmd import ( "context" - "encoding/json" "errors" "fmt" "io" "os" "path/filepath" "runtime/trace" - "strings" - "sync" - "google.golang.org/grpc" - - "github.com/sqlc-dev/sqlc/internal/codegen/golang" - genjson "github.com/sqlc-dev/sqlc/internal/codegen/json" "github.com/sqlc-dev/sqlc/internal/compiler" "github.com/sqlc-dev/sqlc/internal/config" - "github.com/sqlc-dev/sqlc/internal/config/convert" "github.com/sqlc-dev/sqlc/internal/debug" - "github.com/sqlc-dev/sqlc/internal/ext" - "github.com/sqlc-dev/sqlc/internal/ext/process" - "github.com/sqlc-dev/sqlc/internal/ext/wasm" "github.com/sqlc-dev/sqlc/internal/multierr" "github.com/sqlc-dev/sqlc/internal/opts" - "github.com/sqlc-dev/sqlc/internal/plugin" ) const errMessageNoVersion = `The configuration file must have a version number. @@ -51,15 +39,6 @@ func printFileErr(stderr io.Writer, dir string, fileErr *multierr.FileError) { fmt.Fprintf(stderr, "%s:%d:%d: %s\n", filename, fileErr.Line, fileErr.Column, fileErr.Err) } -func findPlugin(conf config.Config, name string) (*config.Plugin, error) { - for _, plug := range conf.Plugins { - if plug.Name == name { - return &plug, nil - } - } - return nil, fmt.Errorf("plugin not found") -} - func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, error) { configPath := "" if filename != "" { @@ -127,100 +106,6 @@ func readConfig(stderr io.Writer, dir, filename string) (string, *config.Config, return configPath, &conf, nil } -func Generate(ctx context.Context, dir, filename string, o *Options) (map[string]string, error) { - e := o.Env - stderr := o.Stderr - - configPath, conf, err := o.ReadConfig(dir, filename) - if err != nil { - return nil, err - } - - base := filepath.Base(configPath) - if err := config.Validate(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) - return nil, err - } - - if err := e.Validate(conf); err != nil { - fmt.Fprintf(stderr, "error validating %s: %s\n", base, err) - return nil, err - } - - g := &generator{ - dir: dir, - output: map[string]string{}, - } - - if err := processQuerySets(ctx, g, conf, dir, o); err != nil { - return nil, err - } - - return g.output, nil -} - -type generator struct { - m sync.Mutex - dir string - output map[string]string -} - -func (g *generator) Pairs(ctx context.Context, conf *config.Config) []OutputPair { - var pairs []OutputPair - for _, sql := range conf.SQL { - if sql.Gen.Go != nil { - pairs = append(pairs, OutputPair{ - SQL: sql, - Gen: config.SQLGen{Go: sql.Gen.Go}, - }) - } - if sql.Gen.JSON != nil { - pairs = append(pairs, OutputPair{ - SQL: sql, - Gen: config.SQLGen{JSON: sql.Gen.JSON}, - }) - } - for i := range sql.Codegen { - pairs = append(pairs, OutputPair{ - SQL: sql, - Plugin: &sql.Codegen[i], - }) - } - } - return pairs -} - -func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result) error { - out, resp, err := codegen(ctx, combo, sql, result) - if err != nil { - return err - } - files := map[string]string{} - for _, file := range resp.Files { - files[file.Name] = string(file.Contents) - } - g.m.Lock() - - // out is specified by the user, not a plugin - absout := filepath.Join(g.dir, out) - - for n, source := range files { - filename := filepath.Join(g.dir, out, n) - // filepath.Join calls filepath.Clean which should remove all "..", but - // double check to make sure - if strings.Contains(filename, "..") { - return fmt.Errorf("invalid file output path: %s", filename) - } - // The output file must be contained inside the output directory - if !strings.HasPrefix(filename, absout) { - return fmt.Errorf("invalid file output path: %s", filename) - } - g.output[filename] = source - } - g.m.Unlock() - return nil -} - func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { defer trace.StartRegion(ctx, "parse").End() c, err := compiler.NewCompiler(sql, combo, parserOpts) @@ -260,82 +145,3 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C } return c.Result(), false } - -func codegen(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) { - defer trace.StartRegion(ctx, "codegen").End() - req := codeGenRequest(result, combo) - var handler grpc.ClientConnInterface - var out string - switch { - case sql.Plugin != nil: - out = sql.Plugin.Out - plug, err := findPlugin(combo.Global, sql.Plugin.Plugin) - if err != nil { - return "", nil, fmt.Errorf("plugin not found: %s", err) - } - - switch { - case plug.Process != nil: - handler = &process.Runner{ - Cmd: plug.Process.Cmd, - Env: plug.Env, - Format: plug.Process.Format, - } - case plug.WASM != nil: - handler = &wasm.Runner{ - URL: plug.WASM.URL, - SHA256: plug.WASM.SHA256, - Env: plug.Env, - } - default: - return "", nil, fmt.Errorf("unsupported plugin type") - } - - opts, err := convert.YAMLtoJSON(sql.Plugin.Options) - if err != nil { - return "", nil, fmt.Errorf("invalid plugin options: %w", err) - } - req.PluginOptions = opts - - global, found := combo.Global.Options[plug.Name] - if found { - opts, err := convert.YAMLtoJSON(global) - if err != nil { - return "", nil, fmt.Errorf("invalid global options: %w", err) - } - req.GlobalOptions = opts - } - - case sql.Gen.Go != nil: - out = combo.Go.Out - handler = ext.HandleFunc(golang.Generate) - opts, err := json.Marshal(sql.Gen.Go) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.PluginOptions = opts - - if combo.Global.Overrides.Go != nil { - opts, err := json.Marshal(combo.Global.Overrides.Go) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.GlobalOptions = opts - } - - case sql.Gen.JSON != nil: - out = combo.JSON.Out - handler = ext.HandleFunc(genjson.Generate) - opts, err := json.Marshal(sql.Gen.JSON) - if err != nil { - return "", nil, fmt.Errorf("opts marshal failed: %w", err) - } - req.PluginOptions = opts - - default: - return "", nil, fmt.Errorf("missing language backend") - } - client := plugin.NewCodegenServiceClient(handler) - resp, err := client.Generate(ctx, req) - return out, resp, err -} diff --git a/internal/cmd/options.go b/internal/cmd/options.go index 02d3614f4e..7a1fab33f0 100644 --- a/internal/cmd/options.go +++ b/internal/cmd/options.go @@ -12,18 +12,8 @@ type Options struct { // TODO: Move these to a command-specific struct Tags []string Against string - - // Testing only - MutateConfig func(*config.Config) } func (o *Options) ReadConfig(dir, filename string) (string, *config.Config, error) { - path, conf, err := readConfig(o.Stderr, dir, filename) - if err != nil { - return path, conf, err - } - if o.MutateConfig != nil { - o.MutateConfig(conf) - } - return path, conf, nil + return readConfig(o.Stderr, dir, filename) } diff --git a/internal/cmd/process.go b/internal/cmd/process.go index 5003d113b8..c5e3286299 100644 --- a/internal/cmd/process.go +++ b/internal/cmd/process.go @@ -74,15 +74,21 @@ func processQuerySets(ctx context.Context, rp ResultProcessor, conf *config.Conf } // TODO: This feels like a hack that will bite us later + joinDir := func(p string) string { + if filepath.IsAbs(p) { + return p + } + return filepath.Join(dir, p) + } joined := make([]string, 0, len(sql.Schema)) for _, s := range sql.Schema { - joined = append(joined, filepath.Join(dir, s)) + joined = append(joined, joinDir(s)) } sql.Schema = joined joined = make([]string, 0, len(sql.Queries)) for _, q := range sql.Queries { - joined = append(joined, filepath.Join(dir, q)) + joined = append(joined, joinDir(q)) } sql.Queries = joined diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index 4dbd3c3b7b..3d70d9a301 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -467,15 +467,21 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { combo := config.Combine(*c.Conf, s) // TODO: This feels like a hack that will bite us later + joinDir := func(p string) string { + if filepath.IsAbs(p) { + return p + } + return filepath.Join(c.Dir, p) + } joined := make([]string, 0, len(s.Schema)) - for _, s := range s.Schema { - joined = append(joined, filepath.Join(c.Dir, s)) + for _, p := range s.Schema { + joined = append(joined, joinDir(p)) } s.Schema = joined joined = make([]string, 0, len(s.Queries)) for _, q := range s.Queries { - joined = append(joined, filepath.Join(c.Dir, q)) + joined = append(joined, joinDir(q)) } s.Queries = joined diff --git a/internal/config/config.go b/internal/config/config.go index d3e610ef05..19bcda754a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -189,6 +189,30 @@ func (a *AnalyzerDatabase) UnmarshalYAML(unmarshal func(interface{}) error) erro return errors.New("analyzer.database must be true, false, or \"only\"") } +// MarshalJSON serialises the value back to its source form so configs +// produced via UnmarshalJSON round-trip through json.Marshal. +func (a AnalyzerDatabase) MarshalJSON() ([]byte, error) { + if a.isOnly { + return json.Marshal("only") + } + if a.value == nil { + return json.Marshal(true) + } + return json.Marshal(*a.value) +} + +// MarshalYAML serialises the value back to its source form so configs +// produced via UnmarshalYAML round-trip through yaml.Marshal. +func (a AnalyzerDatabase) MarshalYAML() (interface{}, error) { + if a.isOnly { + return "only", nil + } + if a.value == nil { + return true, nil + } + return *a.value, nil +} + type Analyzer struct { Database AnalyzerDatabase `json:"database" yaml:"database"` } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 91e44ff7f0..9bcb36d482 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -3,6 +3,8 @@ package main import ( "bytes" "context" + "fmt" + "io" "os" osexec "os/exec" "path/filepath" @@ -13,7 +15,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "gopkg.in/yaml.v3" + "github.com/sqlc-dev/sqlc/internal/api" "github.com/sqlc-dev/sqlc/internal/cmd" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/opts" @@ -58,15 +62,14 @@ func TestExamples(t *testing.T) { t.Parallel() path := filepath.Join(examples, tc) var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, + res := api.Generate(ctx, api.GenerateOptions{ + Config: openConfigReader(t, path), Stderr: &stderr, - } - output, err := cmd.Generate(ctx, path, "", opts) - if err != nil { + }) + if len(res.Errors) > 0 { t.Fatalf("sqlc generate failed: %s", stderr.String()) } - cmpDirectory(t, path, output) + cmpDirectory(t, path, res.Files) }) } } @@ -88,20 +91,137 @@ func BenchmarkExamples(b *testing.B) { tc := replay.Name() b.Run(tc, func(b *testing.B) { path := filepath.Join(examples, tc) + cfg := openConfigBytes(b, path) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, + api.Generate(ctx, api.GenerateOptions{ + Config: bytes.NewReader(cfg), Stderr: &stderr, - } - cmd.Generate(ctx, path, "", opts) + }) } }) } } +// openConfigReader reads the sqlc config in dir, rewrites every relative +// schema/queries/output path to an absolute one (so api.Generate doesn't have +// to know the config's directory), and returns the result as an io.Reader. +func openConfigReader(t testing.TB, dir string) io.Reader { + return bytes.NewReader(openConfigBytes(t, dir)) +} + +func openConfigBytes(t testing.TB, dir string) []byte { + t.Helper() + data, _ := mutatedConfigBytes(t, dir, nil) + return data +} + +// mutatedConfigBytes parses the sqlc config in dir, applies mutate (when +// non-nil) to the in-memory Config, makes every path absolute relative to dir, +// and re-encodes as YAML. Parsing v1 configs converts them to a v2-shaped +// Config; we force version "2" so the result can be parsed back by api.Generate. +// +// When mutate is non-nil, the encoded bytes are also written to a temp file +// alongside the original (cleaned up at test end) and the filename is returned +// so callers like cmd.Vet that still take a config-file path can use it. +func mutatedConfigBytes(t testing.TB, dir string, mutate func(*config.Config)) ([]byte, string) { + t.Helper() + original, conf, err := readSqlcConfig(dir) + if err != nil { + t.Fatalf("read sqlc config from %s: %s", dir, err) + } + + conf.Version = "2" + absolutizePaths(conf, dir) + if mutate != nil { + mutate(conf) + } + + var buf bytes.Buffer + enc := yaml.NewEncoder(&buf) + if err := enc.Encode(conf); err != nil { + t.Fatalf("encode config: %s", err) + } + if err := enc.Close(); err != nil { + t.Fatalf("close yaml encoder: %s", err) + } + data := buf.Bytes() + + if mutate == nil { + return data, "" + } + + f, err := os.CreateTemp(dir, "sqlc.test-*"+filepath.Ext(original)) + if err != nil { + t.Fatalf("create temp config in %s: %s", dir, err) + } + t.Cleanup(func() { os.Remove(f.Name()) }) + if _, err := f.Write(data); err != nil { + f.Close() + t.Fatalf("write temp config %s: %s", f.Name(), err) + } + if err := f.Close(); err != nil { + t.Fatalf("close temp config %s: %s", f.Name(), err) + } + return data, filepath.Base(f.Name()) +} + +func absolutizePaths(conf *config.Config, dir string) { + abs := func(p string) string { + if p == "" || filepath.IsAbs(p) { + return p + } + return filepath.Join(dir, p) + } + for i := range conf.SQL { + s := &conf.SQL[i] + for j, p := range s.Schema { + s.Schema[j] = abs(p) + } + for j, p := range s.Queries { + s.Queries[j] = abs(p) + } + if s.Gen.Go != nil { + s.Gen.Go.Out = abs(s.Gen.Go.Out) + } + if s.Gen.JSON != nil { + s.Gen.JSON.Out = abs(s.Gen.JSON.Out) + } + for j := range s.Codegen { + s.Codegen[j].Out = abs(s.Codegen[j].Out) + } + } +} + +func readSqlcConfig(dir string) (string, *config.Config, error) { + for _, name := range []string{"sqlc.yaml", "sqlc.yml", "sqlc.json"} { + path := filepath.Join(dir, name) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + continue + } + return path, nil, err + } + defer f.Close() + conf, err := config.ParseConfig(f) + if err != nil { + return path, nil, fmt.Errorf("parse %s: %w", path, err) + } + return path, &conf, nil + } + return "", nil, fmt.Errorf("no sqlc config found in %s", dir) +} + +// configRef is the result of preparing a config for a single TestReplay case. +// reader is for api.Generate; file is for cmd.Vet which still takes a path. +type configRef struct { + reader io.Reader + file string +} + type textContext struct { - Mutate func(*testing.T, string) func(*config.Config) + Config func(*testing.T, string) configRef Enabled func() bool } @@ -173,45 +293,28 @@ func TestReplay(t *testing.T) { contexts := map[string]textContext{ "base": { - Mutate: func(t *testing.T, path string) func(*config.Config) { return func(c *config.Config) {} }, + Config: func(t *testing.T, dir string) configRef { + data, _ := mutatedConfigBytes(t, dir, nil) + return configRef{reader: bytes.NewReader(data)} + }, Enabled: func() bool { return true }, }, "managed-db": { - Mutate: func(t *testing.T, path string) func(*config.Config) { - return func(c *config.Config) { + Config: func(t *testing.T, dir string) configRef { + data, file := mutatedConfigBytes(t, dir, func(c *config.Config) { // Add all servers - tests will fail if database isn't available c.Servers = []config.Server{ - { - Name: "postgres", - Engine: config.EnginePostgreSQL, - URI: postgresURI, - }, - { - Name: "mysql", - Engine: config.EngineMySQL, - URI: mysqlURI, - }, + {Name: "postgres", Engine: config.EnginePostgreSQL, URI: postgresURI}, + {Name: "mysql", Engine: config.EngineMySQL, URI: mysqlURI}, } - for i := range c.SQL { switch c.SQL[i].Engine { - case config.EnginePostgreSQL: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - case config.EngineMySQL: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - case config.EngineSQLite: - c.SQL[i].Database = &config.Database{ - Managed: true, - } - default: - // pass + case config.EnginePostgreSQL, config.EngineMySQL, config.EngineSQLite: + c.SQL[i].Database = &config.Database{Managed: true} } } - } + }) + return configRef{reader: bytes.NewReader(data), file: file} }, Enabled: func() bool { // Enabled if at least one database URI is available @@ -261,25 +364,39 @@ func TestReplay(t *testing.T) { } } - opts := cmd.Options{ + cfg := testctx.Config(t, path) + cmdOpts := cmd.Options{ Env: cmd.Env{ Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]), Experiment: opts.ExperimentFromString(args.Env["SQLCEXPERIMENT"]), }, - Stderr: &stderr, - MutateConfig: testctx.Mutate(t, path), + Stderr: &stderr, } switch args.Command { case "diff": - err = cmd.Diff(ctx, path, "", &opts) + res := api.Generate(ctx, api.GenerateOptions{ + Config: cfg.reader, + Stderr: &stderr, + Diff: true, + }) + if len(res.Errors) > 0 { + err = res.Errors[0] + } case "generate": - output, err = cmd.Generate(ctx, path, "", &opts) + res := api.Generate(ctx, api.GenerateOptions{ + Config: cfg.reader, + Stderr: &stderr, + }) + output = res.Files + if len(res.Errors) > 0 { + err = res.Errors[0] + } if err == nil { cmpDirectory(t, path, output) } case "vet": - err = cmd.Vet(ctx, path, "", &opts) + err = cmd.Vet(ctx, path, cfg.file, &cmdOpts) default: t.Fatalf("unknown command") } @@ -323,6 +440,10 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { if filepath.Base(path) == "exec.json" { return nil } + // Mutated configs written by mutatedConfigBytes. + if strings.HasPrefix(filepath.Base(path), "sqlc.test-") { + return nil + } if strings.Contains(path, "/kotlin/build") { return nil } @@ -385,13 +506,13 @@ func BenchmarkReplay(b *testing.B) { tc := replay b.Run(tc, func(b *testing.B) { path, _ := filepath.Abs(tc) + cfg := openConfigBytes(b, path) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - opts := &cmd.Options{ - Env: cmd.Env{}, + api.Generate(ctx, api.GenerateOptions{ + Config: bytes.NewReader(cfg), Stderr: &stderr, - } - cmd.Generate(ctx, path, "", opts) + }) } }) }