Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export Paramgen #131

Merged
merged 8 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions lang/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@

package lang

// Ptr returns a pointer to the value passed in.
func Ptr[T any](t T) *T {
return &t
}

// ValOrZero returns the value of the pointer passed in or the zero value of the
// type if the pointer is nil.
func ValOrZero[T any](t *T) T {
if t == nil {
return Zero[T]()
}
return *t
}

// Zero returns the zero value of the type passed in.
func Zero[T any]() T {
var t T
return t
}
6 changes: 3 additions & 3 deletions paramgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"os"
"strings"

"github.com/conduitio/conduit-commons/paramgen/internal"
"github.com/conduitio/conduit-commons/paramgen/paramgen"
)

func main() {
Expand All @@ -32,12 +32,12 @@ func main() {
args := parseFlags()

// parse the sdk parameters
params, pkg, err := internal.ParseParameters(args.path, args.structName)
params, pkg, err := paramgen.ParseParameters(args.path, args.structName)
if err != nil {
log.Fatalf("error: failed to parse parameters: %v", err)
}

code := internal.GenerateCode(params, pkg, args.structName)
code := paramgen.GenerateCode(params, pkg, args.structName)

path := strings.TrimSuffix(args.path, "/") + "/" + args.output
err = os.WriteFile(path, []byte(code), 0o600)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package internal
package paramgen

import (
"os"
Expand All @@ -39,6 +39,10 @@ func TestIntegration(t *testing.T) {
havePath: "./testdata/tags",
structName: "Config",
wantPath: "./testdata/tags/want.go",
}, {
havePath: "./testdata/dependencies",
structName: "Config",
wantPath: "./testdata/dependencies/want.go",
}}

for _, tc := range testCases {
Expand Down
166 changes: 121 additions & 45 deletions paramgen/internal/paramgen.go → paramgen/paramgen/paramgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
// limitations under the License.

//nolint:err113,wrapcheck,staticcheck // we don't care about wrapping errors here, also ignore usage of ast.Package (deprecated)
package internal
package paramgen

import (
"encoding/json"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io"
"io/fs"
"os/exec"
"reflect"
Expand Down Expand Up @@ -115,31 +116,34 @@ func parsePackage(path string) (*ast.Package, error) {
filterTests := func(info fs.FileInfo) bool {
return !strings.HasSuffix(info.Name(), "_test.go")
}
pkgs, err := parser.ParseDir(fset, path, filterTests, parser.ParseComments)
pkgs, err := parser.ParseDir(fset, path, filterTests, parser.ParseComments|parser.SkipObjectResolution)
if err != nil {
return nil, fmt.Errorf("couldn't parse directory %s: %w", path, err)
}
// Make sure they are all in one package.
if len(pkgs) == 0 {
return nil, fmt.Errorf("no source-code package in directory %s", path)
}
// Ignore files with go:build constraint set to "tools" (common pattern in
// Conduit connectors).
for pkgName, pkg := range pkgs {
// Ignore files with go:build constraint set to "tools" (common pattern in
// Conduit connectors).
maps.DeleteFunc(pkg.Files, func(_ string, f *ast.File) bool {
return hasBuildConstraint(f, "tools")
})
if len(pkg.Files) == 0 {
// Remove empty packages or the main package (can't be imported).
if len(pkg.Files) == 0 || pkgName == "main" {
delete(pkgs, pkgName)
}
}
if len(pkgs) > 1 {

// Make sure there is only 1 package.
switch len(pkgs) {
case 0:
return nil, fmt.Errorf("no source-code package in directory %s", path)
case 1:
for _, pkg := range pkgs {
return pkg, nil
}
panic("unreachable")
default:
return nil, fmt.Errorf("multiple packages %v in directory %s", maps.Keys(pkgs), path)
}
for _, v := range pkgs {
return v, nil // return first package
}
panic("unreachable")
}

// hasBuildConstraint is a very naive way to check if a file has a build
Expand Down Expand Up @@ -205,6 +209,10 @@ func (p *parameterParser) Parse(structType *ast.StructType) (map[string]config.P
}

func (p *parameterParser) parseIdent(ident *ast.Ident, field *ast.Field) (params map[string]config.Parameter, err error) {
if field != nil && p.shouldSkipField(field) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseIdent] %w", err)
Expand Down Expand Up @@ -252,6 +260,10 @@ func (p *parameterParser) parseIdent(ident *ast.Ident, field *ast.Field) (params
}

func (p *parameterParser) parseTypeSpec(ts *ast.TypeSpec, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseTypeSpec] %w", err)
Expand All @@ -267,12 +279,18 @@ func (p *parameterParser) parseTypeSpec(ts *ast.TypeSpec, f *ast.Field) (params
return p.parseIdent(v, f)
case *ast.MapType:
return p.parseMapType(v, f)
case *ast.InterfaceType:
return nil, fmt.Errorf("error parsing type spec for %s.%s.%s: interface types not supported", p.pkg.Name, ts.Name.Name, p.getFieldNameOrUnknown(f))
default:
return nil, fmt.Errorf("unexpected type: %T", ts.Type)
}
}

func (p *parameterParser) parseStructType(st *ast.StructType, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseStructType] %w", err)
Expand Down Expand Up @@ -303,6 +321,10 @@ func (p *parameterParser) parseStructType(st *ast.StructType, f *ast.Field) (par
}

func (p *parameterParser) parseField(f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseField] %w", err)
Expand All @@ -313,34 +335,45 @@ func (p *parameterParser) parseField(f *ast.Field) (params map[string]config.Par
return nil, nil //nolint:nilnil // ignore unexported fields
}

switch v := f.Type.(type) {
case *ast.Ident:
// identifier (builtin type or type in same package)
return p.parseIdent(v, f)
case *ast.StructType:
// nested type
return p.parseStructType(v, f)
case *ast.SelectorExpr:
return p.parseSelectorExpr(v, f)
case *ast.MapType:
return p.parseMapType(v, f)
case *ast.ArrayType:
strType := fmt.Sprintf("%s", v.Elt)
if !p.isBuiltinType(strType) && !strings.Contains(strType, "time Duration") {
return nil, fmt.Errorf("unsupported slice type: %s", strType)
}
expr := f.Type
for {
switch v := expr.(type) {
case *ast.StarExpr:
// dereference pointer
expr = v.X
continue
case *ast.Ident:
// identifier (builtin type or type in same package)
return p.parseIdent(v, f)
case *ast.StructType:
// nested type
return p.parseStructType(v, f)
case *ast.SelectorExpr:
return p.parseSelectorExpr(v, f)
case *ast.MapType:
return p.parseMapType(v, f)
case *ast.ArrayType:
strType := fmt.Sprintf("%s", v.Elt)
if !p.isBuiltinType(strType) && !strings.Contains(strType, "time Duration") {
return nil, fmt.Errorf("unsupported slice type: %s", strType)
}

name, param, err := p.parseSingleParameter(f, config.ParameterTypeString)
if err != nil {
return nil, err
name, param, err := p.parseSingleParameter(f, config.ParameterTypeString)
if err != nil {
return nil, err
}
return map[string]config.Parameter{name: param}, nil
default:
return nil, fmt.Errorf("unknown type: %T", f.Type)
}
return map[string]config.Parameter{name: param}, nil
default:
return nil, fmt.Errorf("unknown type: %T", f.Type)
}
}

func (p *parameterParser) parseMapType(mt *ast.MapType, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

if fmt.Sprintf("%s", mt.Key) != "string" {
return nil, fmt.Errorf("unsupported map key type: %s", mt.Key)
}
Expand Down Expand Up @@ -378,6 +411,10 @@ func (p *parameterParser) parseMapType(mt *ast.MapType, f *ast.Field) (params ma
}

func (p *parameterParser) parseSelectorExpr(se *ast.SelectorExpr, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseSelectorExpr] %w", err)
Expand Down Expand Up @@ -428,17 +465,21 @@ func (p *parameterParser) findPackage(importPath string) (*ast.Package, error) {
// first cleanup string
importPath = strings.Trim(importPath, `"`)

if !strings.HasPrefix(importPath, p.mod.Path) {
// we only allow types declared in the same module
return nil, fmt.Errorf("we do not support parameters from package %v (please use builtin types or time.Duration)", importPath)
}

if pkg, ok := p.imports[importPath]; ok {
// it's cached already
return pkg, nil
}

pkgDir := p.mod.Dir + strings.TrimPrefix(importPath, p.mod.Path)
if !strings.HasPrefix(importPath, p.mod.Path) {
// Import path is not part of the module, we need to find the package path
var err error
pkgDir, err = p.packageToPath(importPath)
if err != nil {
return nil, fmt.Errorf("could not get package path for %q: %w", importPath, err)
}
}

pkg, err := parsePackage(pkgDir)
if err != nil {
return nil, fmt.Errorf("could not parse package dir %q: %w", pkgDir, err)
Expand Down Expand Up @@ -514,6 +555,11 @@ func (p *parameterParser) attachPrefix(f *ast.Field, params map[string]config.Pa
return prefixedParams
}

func (p *parameterParser) shouldSkipField(f *ast.Field) bool {
val := p.getTag(f.Tag, tagParamName)
return val == "-"
}

func (p *parameterParser) isBuiltinType(name string) bool {
switch name {
case "string", "bool", "int", "uint", "int8", "uint8", "int16", "uint16", "int32", "uint32", "int64", "uint64",
Expand Down Expand Up @@ -593,7 +639,7 @@ func (p *parameterParser) getParamType(i *ast.Ident) config.ParameterType {
// lowercase letter. If the string starts with multiple uppercase letters, all
// but the last character in the sequence will be converted into lowercase
// letters (e.g. HTTPRequest -> httpRequest).
func (p *parameterParser) formatFieldName(name string) string {
func (*parameterParser) formatFieldName(name string) string {
if name == "" {
return ""
}
Expand All @@ -619,7 +665,7 @@ func (p *parameterParser) formatFieldName(name string) string {
return newName
}

func (p *parameterParser) formatFieldComment(f *ast.Field) string {
func (*parameterParser) formatFieldComment(f *ast.Field) string {
doc := f.Doc
if doc == nil {
// fallback to line comment
Expand All @@ -644,7 +690,7 @@ func (p *parameterParser) formatFieldComment(f *ast.Field) string {
return c
}

func (p *parameterParser) getTag(lit *ast.BasicLit, tag string) string {
func (*parameterParser) getTag(lit *ast.BasicLit, tag string) string {
if lit == nil {
return ""
}
Expand All @@ -671,7 +717,7 @@ func (p *parameterParser) parseValidateTag(tag string) ([]config.Validation, err
return validations, nil
}

func (p *parameterParser) parseValidation(str string) (config.Validation, error) {
func (*parameterParser) parseValidation(str string) (config.Validation, error) {
if str == validationRequired {
return config.ValidationRequired{}, nil
}
Expand Down Expand Up @@ -715,3 +761,33 @@ func (p *parameterParser) parseValidation(str string) (config.Validation, error)
return nil, fmt.Errorf("invalid value for tag validate: %s", str)
}
}

// packageToPath takes a package import path and returns the path to the directory
// of that package.
func (p *parameterParser) packageToPath(pkg string) (string, error) {
cmd := exec.Command("go", "list", "-f", "{{.Dir}}", pkg)
cmd.Dir = p.mod.Dir
stdout, err := cmd.StdoutPipe()
if err != nil {
return "", fmt.Errorf("error piping stdout of go list command: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return "", fmt.Errorf("error piping stderr of go list command: %w", err)
}
if err := cmd.Start(); err != nil {
return "", fmt.Errorf("error starting go list command: %w", err)
}
path, err := io.ReadAll(stdout)
if err != nil {
return "", fmt.Errorf("error reading stdout of go list command: %w", err)
}
errMsg, err := io.ReadAll(stderr)
if err != nil {
return "", fmt.Errorf("error reading stderr of go list command: %w", err)
}
if err := cmd.Wait(); err != nil {
return "", fmt.Errorf("error running command %q (error message: %q): %w", cmd.String(), errMsg, err)
}
return strings.TrimRight(string(path), "\n"), nil
}
Loading