From 53f10af5d5c7375d4655a3d6852457ed17ab5cc7 Mon Sep 17 00:00:00 2001 From: Judson Lester Date: Sun, 3 Dec 2017 04:59:21 -0800 Subject: [PATCH] kitgen implementation (#589) * cmd/kitgen: parse * First stymie * Sketch developing * Broken sketch. * Further progress - still doesn't build * Need to collect ast.Exprs, not strings * Running downhill. * sd: add Stop method to Instancer interface Every implementation (modulo mock/test impls) already provided this method, and so it makes sense to lift it to the interface definition. Closes #566. * Fruitful avenue. Committing for travel. * Needs uniquify for varnames * Track gauge values and support default tags for dogstatsd Dogstatsd doesn't support sending deltas for gauges, so we must maintain the raw gauge value ourselves. * Service generation tests work * add failing test for cloudwatch metrics, that are not reset * Refactor cloudwatch: Reset().Walk() on every Send(), like influx impl does. Note there is a breaking API change, as the cloudwatch object now has optional parameters. * Tolerate that there may not be any lables, if the teststat.FillCounter() did not add any samples. * Use Cloudwatch options in the struct itself, which is cleaner * sd: fix TestDefaultEndpointer flake, hopefully * util/conn: more detail for flaky test * removed deprecated functions changed `NewContext -> NewOutgoingContext` and `FromContext -> FromIncomingContext` as described in [metadta](https://github.com/grpc/grpc-go/blob/master/metadata/metadata.go) * Functions extracted from inline codefile * Cleaner/easier way for user to specify Cloudwatch metric percentiles. * fix test to read quantile metrics with p prefix * test cloudwatch.WithPercentiles() * Handles anonymous fields for all parts of interface * Handles underscore param names, produces compile-able code * do not prefix metrics with 'p', just like it was previously. * Fix for dogstatsd metrics with default tags and no labelValues Signed-off-by: James Hamlin * Converted flat to a layout - proceeding to implement default * Convenience function for formatting to a tree * Fix spelling of deregisters https://en.wiktionary.org/wiki/deregister * Add basic auth middleware * Default layout * Need to handle mutating trees more effectively to do default layout. * Basic Auth: optimize memory allocation. * Fix typo * cache required creds' slices * Clean up comment * Replacing idents successfully * improve error handling and style * Constructs import paths usefully * Some debugging - transit * Set time unit on metrics.Timer * Changes as per code review * fix missing comma in example histogram code * update_deps.bash: handle detached HEAD better * .travis.yml: go1.9 + tip exclusively * circle.yml: go1.9 exclusively * Selectify works - need to rearrange some idents now * Updating golden masters so that they build * Nearly 100% functionality * Updated masters - all seems to work * Now testing that everything builds * Removing AST experiments * Chopping up long sourcecontext.gog * Tiny little notes * auth/jwt: add claim factory to example * auth/jwt: minor gofmt fixes * fix typo in addcli * Cleaning up some type assertion digging * Remove dependency on juju * Downstream usages of ratelimit package * Recreating profilesvc issue * Adding .ignore for rg * flat layout works with defined types * Default layout mostly works - one remaining selectify issue * Debugging replaces - determined we need to do cloning * Halfway through an edit - taking it home * Works with new profilesvc testcases * Try to fix Thrift failure (again) (#630) * Empty commit to trigger CI * examples/addsvc: rebuild with latest thrift * cmd/kitgen: parse * First stymie * Sketch developing * Broken sketch. * Further progress - still doesn't build * Need to collect ast.Exprs, not strings * Running downhill. * Fruitful avenue. Committing for travel. * Needs uniquify for varnames * Service generation tests work * Functions extracted from inline codefile * Handles anonymous fields for all parts of interface * Handles underscore param names, produces compile-able code * Converted flat to a layout - proceeding to implement default * Convenience function for formatting to a tree * Default layout * Need to handle mutating trees more effectively to do default layout. * Replacing idents successfully * Constructs import paths usefully * Some debugging - transit * Selectify works - need to rearrange some idents now * Updating golden masters so that they build * Nearly 100% functionality * Updated masters - all seems to work * Now testing that everything builds * Removing AST experiments * Chopping up long sourcecontext.gog * Tiny little notes * Cleaning up some type assertion digging * Recreating profilesvc issue * Adding .ignore for rg * flat layout works with defined types * Default layout mostly works - one remaining selectify issue * Debugging replaces - determined we need to do cloning * Halfway through an edit - taking it home * Works with new profilesvc testcases --- cmd/kitgen/.ignore | 1 + cmd/kitgen/arg.go | 36 + cmd/kitgen/ast_helpers.go | 208 +++++ cmd/kitgen/ast_templates.go | 11 + cmd/kitgen/deflayout.go | 63 ++ cmd/kitgen/flatlayout.go | 39 + cmd/kitgen/interface.go | 70 ++ cmd/kitgen/main.go | 156 ++++ cmd/kitgen/main_test.go | 113 +++ cmd/kitgen/method.go | 220 +++++ cmd/kitgen/parsevisitor.go | 178 ++++ cmd/kitgen/path_test.go | 43 + cmd/kitgen/replacewalk.go | 759 ++++++++++++++++++ cmd/kitgen/sourcecontext.go | 53 ++ cmd/kitgen/templates/full.go | 60 ++ .../anonfields/default/endpoints/endpoints.go | 28 + .../testdata/anonfields/default/http/http.go | 24 + .../anonfields/default/service/service.go | 12 + cmd/kitgen/testdata/anonfields/flat/gokit.go | 51 ++ cmd/kitgen/testdata/anonfields/in.go | 6 + .../foo/default/endpoints/endpoints.go | 28 + cmd/kitgen/testdata/foo/default/http/http.go | 24 + .../testdata/foo/default/service/service.go | 12 + cmd/kitgen/testdata/foo/flat/gokit.go | 51 ++ cmd/kitgen/testdata/foo/in.go | 5 + .../profilesvc/default/endpoints/endpoints.go | 162 ++++ .../testdata/profilesvc/default/http/http.go | 104 +++ .../profilesvc/default/service/service.go | 45 ++ cmd/kitgen/testdata/profilesvc/flat/gokit.go | 298 +++++++ cmd/kitgen/testdata/profilesvc/in.go | 24 + .../default/endpoints/endpoints.go | 44 + .../stringservice/default/http/http.go | 34 + .../stringservice/default/service/service.go | 15 + .../testdata/stringservice/flat/gokit.go | 80 ++ cmd/kitgen/testdata/stringservice/in.go | 6 + .../default/endpoints/endpoints.go | 27 + .../testdata/underscores/default/http/http.go | 24 + .../underscores/default/service/service.go | 12 + cmd/kitgen/testdata/underscores/flat/gokit.go | 50 ++ cmd/kitgen/testdata/underscores/in.go | 7 + cmd/kitgen/transform.go | 230 ++++++ 41 files changed, 3413 insertions(+) create mode 100644 cmd/kitgen/.ignore create mode 100644 cmd/kitgen/arg.go create mode 100644 cmd/kitgen/ast_helpers.go create mode 100644 cmd/kitgen/ast_templates.go create mode 100644 cmd/kitgen/deflayout.go create mode 100644 cmd/kitgen/flatlayout.go create mode 100644 cmd/kitgen/interface.go create mode 100644 cmd/kitgen/main.go create mode 100644 cmd/kitgen/main_test.go create mode 100644 cmd/kitgen/method.go create mode 100644 cmd/kitgen/parsevisitor.go create mode 100644 cmd/kitgen/path_test.go create mode 100644 cmd/kitgen/replacewalk.go create mode 100644 cmd/kitgen/sourcecontext.go create mode 100644 cmd/kitgen/templates/full.go create mode 100644 cmd/kitgen/testdata/anonfields/default/endpoints/endpoints.go create mode 100644 cmd/kitgen/testdata/anonfields/default/http/http.go create mode 100644 cmd/kitgen/testdata/anonfields/default/service/service.go create mode 100644 cmd/kitgen/testdata/anonfields/flat/gokit.go create mode 100644 cmd/kitgen/testdata/anonfields/in.go create mode 100644 cmd/kitgen/testdata/foo/default/endpoints/endpoints.go create mode 100644 cmd/kitgen/testdata/foo/default/http/http.go create mode 100644 cmd/kitgen/testdata/foo/default/service/service.go create mode 100644 cmd/kitgen/testdata/foo/flat/gokit.go create mode 100644 cmd/kitgen/testdata/foo/in.go create mode 100644 cmd/kitgen/testdata/profilesvc/default/endpoints/endpoints.go create mode 100644 cmd/kitgen/testdata/profilesvc/default/http/http.go create mode 100644 cmd/kitgen/testdata/profilesvc/default/service/service.go create mode 100644 cmd/kitgen/testdata/profilesvc/flat/gokit.go create mode 100644 cmd/kitgen/testdata/profilesvc/in.go create mode 100644 cmd/kitgen/testdata/stringservice/default/endpoints/endpoints.go create mode 100644 cmd/kitgen/testdata/stringservice/default/http/http.go create mode 100644 cmd/kitgen/testdata/stringservice/default/service/service.go create mode 100644 cmd/kitgen/testdata/stringservice/flat/gokit.go create mode 100644 cmd/kitgen/testdata/stringservice/in.go create mode 100644 cmd/kitgen/testdata/underscores/default/endpoints/endpoints.go create mode 100644 cmd/kitgen/testdata/underscores/default/http/http.go create mode 100644 cmd/kitgen/testdata/underscores/default/service/service.go create mode 100644 cmd/kitgen/testdata/underscores/flat/gokit.go create mode 100644 cmd/kitgen/testdata/underscores/in.go create mode 100644 cmd/kitgen/transform.go diff --git a/cmd/kitgen/.ignore b/cmd/kitgen/.ignore new file mode 100644 index 000000000..747f955ca --- /dev/null +++ b/cmd/kitgen/.ignore @@ -0,0 +1 @@ +testdata/*/*/ diff --git a/cmd/kitgen/arg.go b/cmd/kitgen/arg.go new file mode 100644 index 000000000..bcf4e0a5d --- /dev/null +++ b/cmd/kitgen/arg.go @@ -0,0 +1,36 @@ +package main + +import "go/ast" + +type arg struct { + name, asField *ast.Ident + typ ast.Expr +} + +func (a arg) chooseName(scope *ast.Scope) *ast.Ident { + if a.name == nil || scope.Lookup(a.name.Name) != nil { + return inventName(a.typ, scope) + } + return a.name +} + +func (a arg) field(scope *ast.Scope) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{a.chooseName(scope)}, + Type: a.typ, + } +} + +func (a arg) result() *ast.Field { + return &ast.Field{ + Names: nil, + Type: a.typ, + } +} + +func (a arg) exported() *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{id(export(a.asField.Name))}, + Type: a.typ, + } +} diff --git a/cmd/kitgen/ast_helpers.go b/cmd/kitgen/ast_helpers.go new file mode 100644 index 000000000..ab7c277db --- /dev/null +++ b/cmd/kitgen/ast_helpers.go @@ -0,0 +1,208 @@ +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "strings" + "unicode" +) + +func export(s string) string { + return strings.Title(s) +} + +func unexport(s string) string { + first := true + return strings.Map(func(r rune) rune { + if first { + first = false + return unicode.ToLower(r) + } + return r + }, s) +} + +func inventName(t ast.Expr, scope *ast.Scope) *ast.Ident { + n := baseName(t) + for try := 0; ; try++ { + nstr := pickName(n, try) + obj := ast.NewObj(ast.Var, nstr) + if alt := scope.Insert(obj); alt == nil { + return ast.NewIdent(nstr) + } + } +} + +func baseName(t ast.Expr) string { + switch tt := t.(type) { + default: + panic(fmt.Sprintf("don't know how to choose a base name for %T (%[1]v)", tt)) + case *ast.ArrayType: + return "slice" + case *ast.Ident: + return tt.Name + case *ast.SelectorExpr: + return tt.Sel.Name + } +} + +func pickName(base string, idx int) string { + if idx == 0 { + switch base { + default: + return strings.Split(base, "")[0] + case "Context": + return "ctx" + case "error": + return "err" + } + } + return fmt.Sprintf("%s%d", base, idx) +} + +func scopeWith(names ...string) *ast.Scope { + scope := ast.NewScope(nil) + for _, name := range names { + scope.Insert(ast.NewObj(ast.Var, name)) + } + return scope +} + +type visitFn func(ast.Node, func(ast.Node)) + +func (fn visitFn) Visit(node ast.Node, r func(ast.Node)) Visitor { + fn(node, r) + return fn +} + +func replaceIdent(src ast.Node, named string, with ast.Node) ast.Node { + r := visitFn(func(node ast.Node, replaceWith func(ast.Node)) { + switch id := node.(type) { + case *ast.Ident: + if id.Name == named { + replaceWith(with) + } + } + }) + return WalkReplace(r, src) +} + +func replaceLit(src ast.Node, from, to string) ast.Node { + r := visitFn(func(node ast.Node, replaceWith func(ast.Node)) { + switch lit := node.(type) { + case *ast.BasicLit: + if lit.Value == from { + replaceWith(&ast.BasicLit{Value: to}) + } + } + }) + return WalkReplace(r, src) +} + +func fullAST() *ast.File { + full, err := ASTTemplates.Open("full.go") + if err != nil { + panic(err) + } + f, err := parser.ParseFile(token.NewFileSet(), "templates/full.go", full, parser.DeclarationErrors) + if err != nil { + panic(err) + } + return f +} + +func fetchImports() []*ast.ImportSpec { + return fullAST().Imports +} + +func fetchFuncDecl(name string) *ast.FuncDecl { + f := fullAST() + for _, decl := range f.Decls { + if f, ok := decl.(*ast.FuncDecl); ok && f.Name.Name == name { + return f + } + } + panic(fmt.Errorf("No function called %q in 'templates/full.go'", name)) +} + +func id(name string) *ast.Ident { + return ast.NewIdent(name) +} + +func sel(ids ...*ast.Ident) ast.Expr { + switch len(ids) { + default: + return &ast.SelectorExpr{ + X: sel(ids[:len(ids)-1]...), + Sel: ids[len(ids)-1], + } + case 1: + return ids[0] + case 0: + panic("zero ids to sel()") + } +} + +func typeField(t ast.Expr) *ast.Field { + return &ast.Field{Type: t} +} + +func field(n *ast.Ident, t ast.Expr) *ast.Field { + return &ast.Field{ + Names: []*ast.Ident{n}, + Type: t, + } +} + +func fieldList(list ...*ast.Field) *ast.FieldList { + return &ast.FieldList{List: list} +} + +func mappedFieldList(fn func(arg) *ast.Field, args ...arg) *ast.FieldList { + fl := &ast.FieldList{List: []*ast.Field{}} + for _, a := range args { + fl.List = append(fl.List, fn(a)) + } + return fl +} + +func blockStmt(stmts ...ast.Stmt) *ast.BlockStmt { + return &ast.BlockStmt{ + List: stmts, + } +} + +func structDecl(name *ast.Ident, fields *ast.FieldList) ast.Decl { + return typeDecl(&ast.TypeSpec{ + Name: name, + Type: &ast.StructType{ + Fields: fields, + }, + }) +} + +func typeDecl(ts *ast.TypeSpec) ast.Decl { + return &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{ts}, + } +} + +func pasteStmts(body *ast.BlockStmt, idx int, stmts []ast.Stmt) { + list := body.List + prefix := list[:idx] + suffix := make([]ast.Stmt, len(list)-idx-1) + copy(suffix, list[idx+1:]) + + body.List = append(append(prefix, stmts...), suffix...) +} + +func importFor(is *ast.ImportSpec) *ast.GenDecl { + return &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{is}} +} + +func importSpec(path string) *ast.ImportSpec { + return &ast.ImportSpec{Path: &ast.BasicLit{Kind: token.STRING, Value: `"` + path + `"`}} +} diff --git a/cmd/kitgen/ast_templates.go b/cmd/kitgen/ast_templates.go new file mode 100644 index 000000000..13aa87c20 --- /dev/null +++ b/cmd/kitgen/ast_templates.go @@ -0,0 +1,11 @@ +// This file was automatically generated based on the contents of *.tmpl +// If you need to update this file, change the contents of those files +// (or add new ones) and run 'go generate' + +package main + +import "golang.org/x/tools/godoc/vfs/mapfs" + +var ASTTemplates = mapfs.New(map[string]string{ + `full.go`: "package foo\n\nimport (\n \"context\"\n \"encoding/json\"\n \"errors\"\n \"net/http\"\n\n \"github.com/go-kit/kit/endpoint\"\n httptransport \"github.com/go-kit/kit/transport/http\"\n)\n\ntype ExampleService struct {\n}\n\ntype ExampleRequest struct {\n I int\n S string\n}\ntype ExampleResponse struct {\n S string\n Err error\n}\n\ntype Endpoints struct {\n ExampleEndpoint endpoint.Endpoint\n}\n\nfunc (f ExampleService) ExampleEndpoint(ctx context.Context, i int, s string) (string, error) {\n panic(errors.New(\"not implemented\"))\n}\n\nfunc makeExampleEndpoint(f ExampleService) endpoint.Endpoint {\n return func(ctx context.Context, request interface{}) (interface{}, error) {\n req := request.(ExampleRequest)\n s, err := f.ExampleEndpoint(ctx, req.I, req.S)\n return ExampleResponse{S: s, Err: err}, nil\n }\n}\n\nfunc inlineHandlerBuilder(m *http.ServeMux, endpoints Endpoints) {\n m.Handle(\"/bar\", httptransport.NewServer(endpoints.ExampleEndpoint, DecodeExampleRequest, EncodeExampleResponse))\n}\n\nfunc NewHTTPHandler(endpoints Endpoints) http.Handler {\n m := http.NewServeMux()\n inlineHandlerBuilder(m, endpoints)\n return m\n}\n\nfunc DecodeExampleRequest(_ context.Context, r *http.Request) (interface{}, error) {\n var req ExampleRequest\n err := json.NewDecoder(r.Body).Decode(&req)\n return req, err\n}\n\nfunc EncodeExampleResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {\n w.Header().Set(\"Content-Type\", \"application/json; charset=utf-8\")\n return json.NewEncoder(w).Encode(response)\n}\n", +}) diff --git a/cmd/kitgen/deflayout.go b/cmd/kitgen/deflayout.go new file mode 100644 index 000000000..27e2fec37 --- /dev/null +++ b/cmd/kitgen/deflayout.go @@ -0,0 +1,63 @@ +package main + +import "path/filepath" + +type deflayout struct { + targetDir string +} + +func (l deflayout) packagePath(sub string) string { + return filepath.Join(l.targetDir, sub) +} + +func (l deflayout) transformAST(ctx *sourceContext) (files, error) { + out := make(outputTree) + + endpoints := out.addFile("endpoints/endpoints.go", "endpoints") + http := out.addFile("http/http.go", "http") + service := out.addFile("service/service.go", "service") + + addImports(endpoints, ctx) + addImports(http, ctx) + addImports(service, ctx) + + for _, typ := range ctx.types { + addType(service, typ) + } + + for _, iface := range ctx.interfaces { //only one... + addStubStruct(service, iface) + + for _, meth := range iface.methods { + addMethod(service, iface, meth) + addRequestStruct(endpoints, meth) + addResponseStruct(endpoints, meth) + addEndpointMaker(endpoints, iface, meth) + } + + addEndpointsStruct(endpoints, iface) + addHTTPHandler(http, iface) + + for _, meth := range iface.methods { + addDecoder(http, meth) + addEncoder(http, meth) + } + + for name := range out { + out[name] = selectify(out[name], "service", iface.stubName().Name, l.packagePath("service")) + for _, meth := range iface.methods { + out[name] = selectify(out[name], "endpoints", meth.requestStructName().Name, l.packagePath("endpoints")) + } + } + } + + for name := range out { + out[name] = selectify(out[name], "endpoints", "Endpoints", l.packagePath("endpoints")) + + for _, typ := range ctx.types { + out[name] = selectify(out[name], "service", typ.Name.Name, l.packagePath("service")) + } + } + + return formatNodes(out) +} diff --git a/cmd/kitgen/flatlayout.go b/cmd/kitgen/flatlayout.go new file mode 100644 index 000000000..fedffa4b8 --- /dev/null +++ b/cmd/kitgen/flatlayout.go @@ -0,0 +1,39 @@ +package main + +import "go/ast" + +type flat struct{} + +func (f flat) transformAST(ctx *sourceContext) (files, error) { + root := &ast.File{ + Name: ctx.pkg, + Decls: []ast.Decl{}, + } + + addImports(root, ctx) + + for _, typ := range ctx.types { + addType(root, typ) + } + + for _, iface := range ctx.interfaces { //only one... + addStubStruct(root, iface) + + for _, meth := range iface.methods { + addMethod(root, iface, meth) + addRequestStruct(root, meth) + addResponseStruct(root, meth) + addEndpointMaker(root, iface, meth) + } + + addEndpointsStruct(root, iface) + addHTTPHandler(root, iface) + + for _, meth := range iface.methods { + addDecoder(root, meth) + addEncoder(root, meth) + } + } + + return formatNodes(outputTree{"gokit.go": root}) +} diff --git a/cmd/kitgen/interface.go b/cmd/kitgen/interface.go new file mode 100644 index 000000000..0c984dfca --- /dev/null +++ b/cmd/kitgen/interface.go @@ -0,0 +1,70 @@ +package main + +import "go/ast" + +// because "interface" is a keyword... +type iface struct { + name, stubname, rcvrName *ast.Ident + methods []method +} + +func (i iface) stubName() *ast.Ident { + return i.stubname +} + +func (i iface) stubStructDecl() ast.Decl { + return structDecl(i.stubName(), &ast.FieldList{}) +} + +func (i iface) endpointsStruct() ast.Decl { + fl := &ast.FieldList{} + for _, m := range i.methods { + fl.List = append(fl.List, &ast.Field{Names: []*ast.Ident{m.name}, Type: sel(id("endpoint"), id("Endpoint"))}) + } + return structDecl(id("Endpoints"), fl) +} + +func (i iface) httpHandler() ast.Decl { + handlerFn := fetchFuncDecl("NewHTTPHandler") + + // does this "inlining" process merit a helper akin to replaceIdent? + handleCalls := []ast.Stmt{} + for _, m := range i.methods { + handleCall := fetchFuncDecl("inlineHandlerBuilder").Body.List[0].(*ast.ExprStmt).X.(*ast.CallExpr) + + handleCall = replaceLit(handleCall, `"/bar"`, `"`+m.pathName()+`"`).(*ast.CallExpr) + handleCall = replaceIdent(handleCall, "ExampleEndpoint", m.name).(*ast.CallExpr) + handleCall = replaceIdent(handleCall, "DecodeExampleRequest", m.decodeFuncName()).(*ast.CallExpr) + handleCall = replaceIdent(handleCall, "EncodeExampleResponse", m.encodeFuncName()).(*ast.CallExpr) + + handleCalls = append(handleCalls, &ast.ExprStmt{X: handleCall}) + } + + pasteStmts(handlerFn.Body, 1, handleCalls) + + return handlerFn +} + +func (i iface) reciever() *ast.Field { + return field(i.receiverName(), i.stubName()) +} + +func (i iface) receiverName() *ast.Ident { + if i.rcvrName != nil { + return i.rcvrName + } + scope := ast.NewScope(nil) + for _, meth := range i.methods { + for _, arg := range meth.params { + if arg.name != nil { + scope.Insert(ast.NewObj(ast.Var, arg.name.Name)) + } + } + for _, arg := range meth.results { + if arg.name != nil { + scope.Insert(ast.NewObj(ast.Var, arg.name.Name)) + } + } + } + return id(unexport(inventName(i.name, scope).Name)) +} diff --git a/cmd/kitgen/main.go b/cmd/kitgen/main.go new file mode 100644 index 000000000..fdfd1fb9c --- /dev/null +++ b/cmd/kitgen/main.go @@ -0,0 +1,156 @@ +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io" + "log" + "os" + "path" + + "github.com/pkg/errors" +) + +// go get github.com/nyarly/inlinefiles +//go:generate inlinefiles --package=main --vfs=ASTTemplates ./templates ast_templates.go + +func usage() string { + return fmt.Sprintf("Usage: %s (try -h)", os.Args[0]) +} + +var ( + help = flag.Bool("h", false, "print this help") + layoutkind = flag.String("repo-layout", "default", "default, flat...") + outdirrel = flag.String("target-dir", ".", "base directory to emit into") + //contextOmittable = flag.Bool("allow-no-context", false, "allow service methods to omit context parameter") +) + +func helpText() { + fmt.Println("USAGE") + fmt.Println(" kitgen [flags] path/to/service.go") + fmt.Println("") + fmt.Println("FLAGS") + flag.PrintDefaults() +} + +func main() { + flag.Parse() + + if *help { + helpText() + os.Exit(0) + } + + outdir := *outdirrel + if !path.IsAbs(*outdirrel) { + wd, err := os.Getwd() + if err != nil { + log.Fatalf("error getting current working directory: %v", err) + } + outdir = path.Join(wd, *outdirrel) + } + + var layout layout + switch *layoutkind { + default: + log.Fatalf("Unrecognized layout kind: %q - try 'default' or 'flat'", *layoutkind) + case "default": + gopath := getGopath() + importBase, err := importPath(outdir, gopath) + if err != nil { + log.Fatal(err) + } + layout = deflayout{targetDir: importBase} + case "flat": + layout = flat{} + } + + if len(os.Args) < 2 { + log.Fatal(usage()) + } + filename := flag.Arg(0) + file, err := os.Open(filename) + if err != nil { + log.Fatalf("error while opening %q: %v", filename, err) + } + + tree, err := process(filename, file, layout) + if err != nil { + log.Fatal(err) + } + + err = splat(outdir, tree) + if err != nil { + log.Fatal(err) + } +} + +func process(filename string, source io.Reader, layout layout) (files, error) { + f, err := parseFile(filename, source) + if err != nil { + return nil, errors.Wrapf(err, "parsing input %q", filename) + } + + context, err := extractContext(f) + if err != nil { + return nil, errors.Wrapf(err, "examining input file %q", filename) + } + + tree, err := layout.transformAST(context) + if err != nil { + return nil, errors.Wrapf(err, "generating AST") + } + return tree, nil +} + +/* + buf, err := formatNode(dest) + if err != nil { + return nil, errors.Wrapf(err, "formatting") + } + return buf, nil +} +*/ + +func parseFile(fname string, source io.Reader) (ast.Node, error) { + f, err := parser.ParseFile(token.NewFileSet(), fname, source, parser.DeclarationErrors) + if err != nil { + return nil, err + } + return f, nil +} + +func extractContext(f ast.Node) (*sourceContext, error) { + context := &sourceContext{} + visitor := &parseVisitor{src: context} + + ast.Walk(visitor, f) + + return context, context.validate() +} + +func splat(dir string, tree files) error { + for fn, buf := range tree { + if err := splatFile(path.Join(dir, fn), buf); err != nil { + return err + } + } + return nil +} + +func splatFile(target string, buf io.Reader) error { + err := os.MkdirAll(path.Dir(target), os.ModePerm) + if err != nil { + return errors.Wrapf(err, "Couldn't create directory for %q", target) + } + f, err := os.Create(target) + if err != nil { + return errors.Wrapf(err, "Couldn't create file %q", target) + } + defer f.Close() + _, err = io.Copy(f, buf) + return errors.Wrapf(err, "Error writing data to file %q", target) +} diff --git a/cmd/kitgen/main_test.go b/cmd/kitgen/main_test.go new file mode 100644 index 000000000..4ef5013a8 --- /dev/null +++ b/cmd/kitgen/main_test.go @@ -0,0 +1,113 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +var update = flag.Bool("update", false, "update golden files") + +func TestProcess(t *testing.T) { + cases, err := filepath.Glob("testdata/*") + if err != nil { + t.Fatal(err) + } + + laidout := func(t *testing.T, inpath, dir, kind string, layout layout, in []byte) { + t.Run(kind, func(t *testing.T) { + targetDir := filepath.Join(dir, kind) + tree, err := process(inpath, bytes.NewBuffer(in), layout) + if err != nil { + t.Fatal(inpath, fmt.Sprintf("%+#v", err)) + } + + if *update { + err := splat(targetDir, tree) + if err != nil { + t.Fatal(kind, err) + } + // otherwise we need to do some tomfoolery with resetting buffers + // I'm willing to just run the tests again - besides, we shouldn't be + // regerating the golden files that often + t.Error("Updated outputs - DID NOT COMPARE! (run tests again without -update)") + return + } + + for filename, buf := range tree { + actual, err := ioutil.ReadAll(buf) + if err != nil { + t.Fatal(kind, filename, err) + } + + outpath := filepath.Join(targetDir, filename) + + expected, err := ioutil.ReadFile(outpath) + if err != nil { + t.Fatal(outpath, err) + } + + if !bytes.Equal(expected, actual) { + name := kind + filename + name = strings.Replace(name, "/", "-", -1) + + errfile, err := ioutil.TempFile("", name) + if err != nil { + t.Fatal("opening tempfile for output", err) + } + io.WriteString(errfile, string(actual)) + + diffCmd := exec.Command("diff", outpath, errfile.Name()) + diffOut, _ := diffCmd.Output() + t.Log(string(diffOut)) + t.Errorf("Processing output didn't match %q. Results recorded in %q.", outpath, errfile.Name()) + } + } + + if !t.Failed() { + build := exec.Command("go", "build", "./...") + build.Dir = targetDir + out, err := build.CombinedOutput() + if err != nil { + t.Fatalf("Cannot build output: %v\n%s", err, string(out)) + } + } + }) + + } + + testcase := func(dir string) { + name := filepath.Base(dir) + t.Run(name, func(t *testing.T) { + inpath := filepath.Join(dir, "in.go") + + in, err := ioutil.ReadFile(inpath) + if err != nil { + t.Fatal(inpath, err) + } + laidout(t, inpath, dir, "flat", flat{}, in) + laidout(t, inpath, dir, "default", deflayout{ + targetDir: filepath.Join("github.com/go-kit/kit/cmd/kitgen", dir, "default"), + }, in) + }) + } + + for _, dir := range cases { + testcase(dir) + } +} + +func TestTemplatesBuild(t *testing.T) { + build := exec.Command("go", "build", "./...") + build.Dir = "templates" + out, err := build.CombinedOutput() + if err != nil { + t.Fatal(err, "\n", string(out)) + } +} diff --git a/cmd/kitgen/method.go b/cmd/kitgen/method.go new file mode 100644 index 000000000..14238dd40 --- /dev/null +++ b/cmd/kitgen/method.go @@ -0,0 +1,220 @@ +package main + +import ( + "go/ast" + "go/token" + "strings" +) + +type method struct { + name *ast.Ident + params []arg + results []arg + structsResolved bool +} + +func (m method) definition(ifc iface) ast.Decl { + notImpl := fetchFuncDecl("ExampleEndpoint") + + notImpl.Name = m.name + notImpl.Recv = fieldList(ifc.reciever()) + scope := scopeWith(notImpl.Recv.List[0].Names[0].Name) + notImpl.Type.Params = m.funcParams(scope) + notImpl.Type.Results = m.funcResults() + + return notImpl +} + +func (m method) endpointMaker(ifc iface) ast.Decl { + endpointFn := fetchFuncDecl("makeExampleEndpoint") + scope := scopeWith("ctx", "req", ifc.receiverName().Name) + + anonFunc := endpointFn.Body.List[0].(*ast.ReturnStmt).Results[0].(*ast.FuncLit) + if !m.hasContext() { + // strip context param from endpoint function + anonFunc.Type.Params.List = anonFunc.Type.Params.List[1:] + } + + anonFunc = replaceIdent(anonFunc, "ExampleRequest", m.requestStructName()).(*ast.FuncLit) + callMethod := m.called(ifc, scope, "ctx", "req") + anonFunc.Body.List[1] = callMethod + anonFunc.Body.List[2].(*ast.ReturnStmt).Results[0] = m.wrapResult(callMethod.Lhs) + + endpointFn.Body.List[0].(*ast.ReturnStmt).Results[0] = anonFunc + endpointFn.Name = m.endpointMakerName() + endpointFn.Type.Params = fieldList(ifc.reciever()) + endpointFn.Type.Results = fieldList(typeField(sel(id("endpoint"), id("Endpoint")))) + return endpointFn +} + +func (m method) pathName() string { + return "/" + strings.ToLower(m.name.Name) +} + +func (m method) encodeFuncName() *ast.Ident { + return id("Encode" + m.name.Name + "Response") +} + +func (m method) decodeFuncName() *ast.Ident { + return id("Decode" + m.name.Name + "Request") +} + +func (m method) resultNames(scope *ast.Scope) []*ast.Ident { + ids := []*ast.Ident{} + for _, rz := range m.results { + ids = append(ids, rz.chooseName(scope)) + } + return ids +} + +func (m method) called(ifc iface, scope *ast.Scope, ctxName, spreadStruct string) *ast.AssignStmt { + m.resolveStructNames() + + resNamesExpr := []ast.Expr{} + for _, r := range m.resultNames(scope) { + resNamesExpr = append(resNamesExpr, ast.Expr(r)) + } + + arglist := []ast.Expr{} + if m.hasContext() { + arglist = append(arglist, id(ctxName)) + } + ssid := id(spreadStruct) + for _, f := range m.requestStructFields().List { + arglist = append(arglist, sel(ssid, f.Names[0])) + } + + return &ast.AssignStmt{ + Lhs: resNamesExpr, + Tok: token.DEFINE, + Rhs: []ast.Expr{ + &ast.CallExpr{ + Fun: sel(ifc.receiverName(), m.name), + Args: arglist, + }, + }, + } +} + +func (m method) wrapResult(results []ast.Expr) ast.Expr { + kvs := []ast.Expr{} + m.resolveStructNames() + + for i, a := range m.results { + kvs = append(kvs, &ast.KeyValueExpr{ + Key: ast.NewIdent(export(a.asField.Name)), + Value: results[i], + }) + } + return &ast.CompositeLit{ + Type: m.responseStructName(), + Elts: kvs, + } +} + +func (m method) resolveStructNames() { + if m.structsResolved { + return + } + m.structsResolved = true + scope := ast.NewScope(nil) + for i, p := range m.params { + p.asField = p.chooseName(scope) + m.params[i] = p + } + scope = ast.NewScope(nil) + for i, r := range m.results { + r.asField = r.chooseName(scope) + m.results[i] = r + } +} + +func (m method) decoderFunc() ast.Decl { + fn := fetchFuncDecl("DecodeExampleRequest") + fn.Name = m.decodeFuncName() + fn = replaceIdent(fn, "ExampleRequest", m.requestStructName()).(*ast.FuncDecl) + return fn +} + +func (m method) encoderFunc() ast.Decl { + fn := fetchFuncDecl("EncodeExampleResponse") + fn.Name = m.encodeFuncName() + return fn +} + +func (m method) endpointMakerName() *ast.Ident { + return id("make" + m.name.Name + "Endpoint") +} + +func (m method) requestStruct() ast.Decl { + m.resolveStructNames() + return structDecl(m.requestStructName(), m.requestStructFields()) +} + +func (m method) responseStruct() ast.Decl { + m.resolveStructNames() + return structDecl(m.responseStructName(), m.responseStructFields()) +} + +func (m method) hasContext() bool { + if len(m.params) < 1 { + return false + } + carg := m.params[0].typ + // ugh. this is maybe okay for the one-off, but a general case for matching + // types would be helpful + if sel, is := carg.(*ast.SelectorExpr); is && sel.Sel.Name == "Context" { + if id, is := sel.X.(*ast.Ident); is && id.Name == "context" { + return true + } + } + return false +} + +func (m method) nonContextParams() []arg { + if m.hasContext() { + return m.params[1:] + } + return m.params +} + +func (m method) funcParams(scope *ast.Scope) *ast.FieldList { + parms := &ast.FieldList{} + if m.hasContext() { + parms.List = []*ast.Field{{ + Names: []*ast.Ident{ast.NewIdent("ctx")}, + Type: sel(id("context"), id("Context")), + }} + scope.Insert(ast.NewObj(ast.Var, "ctx")) + } + parms.List = append(parms.List, mappedFieldList(func(a arg) *ast.Field { + return a.field(scope) + }, m.nonContextParams()...).List...) + return parms +} + +func (m method) funcResults() *ast.FieldList { + return mappedFieldList(func(a arg) *ast.Field { + return a.result() + }, m.results...) +} + +func (m method) requestStructName() *ast.Ident { + return id(export(m.name.Name) + "Request") +} + +func (m method) requestStructFields() *ast.FieldList { + return mappedFieldList(func(a arg) *ast.Field { + return a.exported() + }, m.nonContextParams()...) +} + +func (m method) responseStructName() *ast.Ident { + return id(export(m.name.Name) + "Response") +} + +func (m method) responseStructFields() *ast.FieldList { + return mappedFieldList(func(a arg) *ast.Field { + return a.exported() + }, m.results...) +} diff --git a/cmd/kitgen/parsevisitor.go b/cmd/kitgen/parsevisitor.go new file mode 100644 index 000000000..aa5131343 --- /dev/null +++ b/cmd/kitgen/parsevisitor.go @@ -0,0 +1,178 @@ +package main + +import ( + "go/ast" +) + +type ( + parseVisitor struct { + src *sourceContext + } + + typeSpecVisitor struct { + src *sourceContext + node *ast.TypeSpec + iface *iface + name *ast.Ident + } + + interfaceTypeVisitor struct { + node *ast.TypeSpec + ts *typeSpecVisitor + methods []method + } + + methodVisitor struct { + depth int + node *ast.TypeSpec + list *[]method + name *ast.Ident + params, results *[]arg + isMethod bool + } + + argListVisitor struct { + list *[]arg + } + + argVisitor struct { + node *ast.TypeSpec + parts []ast.Expr + list *[]arg + } +) + +func (v *parseVisitor) Visit(n ast.Node) ast.Visitor { + switch rn := n.(type) { + default: + return v + case *ast.File: + v.src.pkg = rn.Name + return v + case *ast.ImportSpec: + v.src.imports = append(v.src.imports, rn) + return nil + + case *ast.TypeSpec: + switch rn.Type.(type) { + default: + v.src.types = append(v.src.types, rn) + case *ast.InterfaceType: + // can't output interfaces + // because they'd conflict with our implementations + } + return &typeSpecVisitor{src: v.src, node: rn} + } +} + +/* +package foo + +type FooService interface { + Bar(ctx context.Context, i int, s string) (string, error) +} +*/ + +func (v *typeSpecVisitor) Visit(n ast.Node) ast.Visitor { + switch rn := n.(type) { + default: + return v + case *ast.Ident: + if v.name == nil { + v.name = rn + } + return v + case *ast.InterfaceType: + return &interfaceTypeVisitor{ts: v, methods: []method{}} + case nil: + if v.iface != nil { + v.iface.name = v.name + sn := *v.name + v.iface.stubname = &sn + v.iface.stubname.Name = v.name.String() + v.src.interfaces = append(v.src.interfaces, *v.iface) + } + return nil + } +} + +func (v *interfaceTypeVisitor) Visit(n ast.Node) ast.Visitor { + switch n.(type) { + default: + return v + case *ast.Field: + return &methodVisitor{list: &v.methods} + case nil: + v.ts.iface = &iface{methods: v.methods} + return nil + } +} + +func (v *methodVisitor) Visit(n ast.Node) ast.Visitor { + switch rn := n.(type) { + default: + v.depth++ + return v + case *ast.Ident: + if rn.IsExported() { + v.name = rn + } + v.depth++ + return v + case *ast.FuncType: + v.depth++ + v.isMethod = true + return v + case *ast.FieldList: + if v.params == nil { + v.params = &[]arg{} + return &argListVisitor{list: v.params} + } + if v.results == nil { + v.results = &[]arg{} + } + return &argListVisitor{list: v.results} + case nil: + v.depth-- + if v.depth == 0 && v.isMethod && v.name != nil { + *v.list = append(*v.list, method{name: v.name, params: *v.params, results: *v.results}) + } + return nil + } +} + +func (v *argListVisitor) Visit(n ast.Node) ast.Visitor { + switch n.(type) { + default: + return nil + case *ast.Field: + return &argVisitor{list: v.list} + } +} + +func (v *argVisitor) Visit(n ast.Node) ast.Visitor { + switch t := n.(type) { + case *ast.CommentGroup, *ast.BasicLit: + return nil + case *ast.Ident: //Expr -> everything, but clarity + if t.Name != "_" { + v.parts = append(v.parts, t) + } + case ast.Expr: + v.parts = append(v.parts, t) + case nil: + names := v.parts[:len(v.parts)-1] + tp := v.parts[len(v.parts)-1] + if len(names) == 0 { + *v.list = append(*v.list, arg{typ: tp}) + return nil + } + for _, n := range names { + *v.list = append(*v.list, arg{ + name: n.(*ast.Ident), + typ: tp, + }) + } + } + return nil +} diff --git a/cmd/kitgen/path_test.go b/cmd/kitgen/path_test.go new file mode 100644 index 000000000..371ded1d1 --- /dev/null +++ b/cmd/kitgen/path_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "strings" + "testing" +) + +func TestImportPath(t *testing.T) { + testcase := func(gopath, targetpath, expected string) { + t.Run(fmt.Sprintf("%q + %q", gopath, targetpath), func(t *testing.T) { + actual, err := importPath(targetpath, gopath) + if err != nil { + t.Fatalf("Expected no error, got %q", err) + } + if actual != expected { + t.Errorf("Expected %q, got %q", expected, actual) + } + }) + } + + testcase("/gopath/", "/gopath/src/somewhere", "somewhere") + testcase("/gopath", "/gopath/src/somewhere", "somewhere") + testcase("/gopath:/other", "/gopath/src/somewhere", "somewhere") + testcase("/other:/gopath/", "/gopath/src/somewhere", "somewhere") +} + +func TestImportPathSadpath(t *testing.T) { + testcase := func(gopath, targetpath, expected string) { + t.Run(fmt.Sprintf("%q + %q", gopath, targetpath), func(t *testing.T) { + actual, err := importPath(targetpath, gopath) + if actual != "" { + t.Errorf("Expected empty path, got %q", actual) + } + if strings.Index(err.Error(), expected) == -1 { + t.Errorf("Expected %q to include %q", err, expected) + } + }) + } + + testcase("", "/gopath/src/somewhere", "is not in") + testcase("", "./somewhere", "not an absolute") +} diff --git a/cmd/kitgen/replacewalk.go b/cmd/kitgen/replacewalk.go new file mode 100644 index 000000000..f6b70dcd9 --- /dev/null +++ b/cmd/kitgen/replacewalk.go @@ -0,0 +1,759 @@ +package main + +import ( + "fmt" + "go/ast" +) + +// A Visitor's Visit method is invoked for each node encountered by walkToReplace. +// If the result visitor w is not nil, walkToReplace visits each of the children +// of node with the visitor w, followed by a call of w.Visit(nil). +type Visitor interface { + Visit(node ast.Node, replace func(ast.Node)) (w Visitor) +} + +// Helper functions for common node lists. They may be empty. + +func walkIdentList(v Visitor, list []*ast.Ident) { + for i, x := range list { + walkToReplace(v, x, func(r ast.Node) { + list[i] = r.(*ast.Ident) + }) + } +} + +func walkExprList(v Visitor, list []ast.Expr) { + for i, x := range list { + walkToReplace(v, x, func(r ast.Node) { + list[i] = r.(ast.Expr) + }) + } +} + +func walkStmtList(v Visitor, list []ast.Stmt) { + for i, x := range list { + walkToReplace(v, x, func(r ast.Node) { + list[i] = r.(ast.Stmt) + }) + } +} + +func walkDeclList(v Visitor, list []ast.Decl) { + for i, x := range list { + walkToReplace(v, x, func(r ast.Node) { + list[i] = r.(ast.Decl) + }) + } +} + +// WalkToReplace traverses an AST in depth-first order: It starts by calling +// v.Visit(node); node must not be nil. If the visitor w returned by +// v.Visit(node) is not nil, walkToReplace is invoked recursively with visitor +// w for each of the non-nil children of node, followed by a call of +// w.Visit(nil). +func WalkReplace(v Visitor, node ast.Node) (replacement ast.Node) { + walkToReplace(v, node, func(r ast.Node) { + replacement = r + }) + return +} + +func walkToReplace(v Visitor, node ast.Node, replace func(ast.Node)) { + if v == nil { + return + } + var replacement ast.Node + repl := func(r ast.Node) { + replacement = r + replace(r) + } + + v = v.Visit(node, repl) + + if replacement != nil { + return + } + + // walk children + // (the order of the cases matches the order + // of the corresponding node types in ast.go) + switch n := node.(type) { + + // These are all leaves, so there's no sub-walk to do. + // We just need to replace them on their parent with a copy. + case *ast.Comment: + cpy := *n + replace(&cpy) + case *ast.BadExpr: + cpy := *n + replace(&cpy) + case *ast.Ident: + cpy := *n + replace(&cpy) + case *ast.BasicLit: + cpy := *n + replace(&cpy) + case *ast.BadDecl: + cpy := *n + replace(&cpy) + case *ast.EmptyStmt: + cpy := *n + replace(&cpy) + case *ast.BadStmt: + cpy := *n + replace(&cpy) + + case *ast.CommentGroup: + cpy := *n + + if n.List != nil { + cpy.List = make([]*ast.Comment, len(n.List)) + copy(cpy.List, n.List) + } + + for i, c := range cpy.List { + walkToReplace(v, c, func(r ast.Node) { + cpy.List[i] = r.(*ast.Comment) + }) + } + replace(&cpy) + + case *ast.Field: + cpy := *n + if n.Names != nil { + cpy.Names = make([]*ast.Ident, len(n.Names)) + copy(cpy.Names, n.Names) + } + + if cpy.Doc != nil { + walkToReplace(v, cpy.Doc, func(r ast.Node) { + cpy.Doc = r.(*ast.CommentGroup) + }) + } + walkIdentList(v, cpy.Names) + + walkToReplace(v, cpy.Type, func(r ast.Node) { + cpy.Type = r.(ast.Expr) + }) + if cpy.Tag != nil { + walkToReplace(v, cpy.Tag, func(r ast.Node) { + cpy.Tag = r.(*ast.BasicLit) + }) + } + if cpy.Comment != nil { + walkToReplace(v, cpy.Comment, func(r ast.Node) { + cpy.Comment = r.(*ast.CommentGroup) + }) + } + replace(&cpy) + + case *ast.FieldList: + cpy := *n + if n.List != nil { + cpy.List = make([]*ast.Field, len(n.List)) + copy(cpy.List, n.List) + } + + for i, f := range cpy.List { + walkToReplace(v, f, func(r ast.Node) { + cpy.List[i] = r.(*ast.Field) + }) + } + + replace(&cpy) + + case *ast.Ellipsis: + cpy := *n + + if cpy.Elt != nil { + walkToReplace(v, cpy.Elt, func(r ast.Node) { + cpy.Elt = r.(ast.Expr) + }) + } + + replace(&cpy) + + case *ast.FuncLit: + cpy := *n + walkToReplace(v, cpy.Type, func(r ast.Node) { + cpy.Type = r.(*ast.FuncType) + }) + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + + replace(&cpy) + case *ast.CompositeLit: + cpy := *n + if n.Elts != nil { + cpy.Elts = make([]ast.Expr, len(n.Elts)) + copy(cpy.Elts, n.Elts) + } + + if cpy.Type != nil { + walkToReplace(v, cpy.Type, func(r ast.Node) { + cpy.Type = r.(ast.Expr) + }) + } + walkExprList(v, cpy.Elts) + + replace(&cpy) + case *ast.ParenExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.SelectorExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + walkToReplace(v, cpy.Sel, func(r ast.Node) { + cpy.Sel = r.(*ast.Ident) + }) + + replace(&cpy) + case *ast.IndexExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + walkToReplace(v, cpy.Index, func(r ast.Node) { + cpy.Index = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.SliceExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + if cpy.Low != nil { + walkToReplace(v, cpy.Low, func(r ast.Node) { + cpy.Low = r.(ast.Expr) + }) + } + if cpy.High != nil { + walkToReplace(v, cpy.High, func(r ast.Node) { + cpy.High = r.(ast.Expr) + }) + } + if cpy.Max != nil { + walkToReplace(v, cpy.Max, func(r ast.Node) { + cpy.Max = r.(ast.Expr) + }) + } + + replace(&cpy) + case *ast.TypeAssertExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + if cpy.Type != nil { + walkToReplace(v, cpy.Type, func(r ast.Node) { + cpy.Type = r.(ast.Expr) + }) + } + replace(&cpy) + case *ast.CallExpr: + cpy := *n + if n.Args != nil { + cpy.Args = make([]ast.Expr, len(n.Args)) + copy(cpy.Args, n.Args) + } + + walkToReplace(v, cpy.Fun, func(r ast.Node) { + cpy.Fun = r.(ast.Expr) + }) + walkExprList(v, cpy.Args) + + replace(&cpy) + case *ast.StarExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.UnaryExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.BinaryExpr: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + walkToReplace(v, cpy.Y, func(r ast.Node) { + cpy.Y = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.KeyValueExpr: + cpy := *n + walkToReplace(v, cpy.Key, func(r ast.Node) { + cpy.Key = r.(ast.Expr) + }) + walkToReplace(v, cpy.Value, func(r ast.Node) { + cpy.Value = r.(ast.Expr) + }) + + replace(&cpy) + + // Types + case *ast.ArrayType: + cpy := *n + if cpy.Len != nil { + walkToReplace(v, cpy.Len, func(r ast.Node) { + cpy.Len = r.(ast.Expr) + }) + } + walkToReplace(v, cpy.Elt, func(r ast.Node) { + cpy.Elt = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.StructType: + cpy := *n + walkToReplace(v, cpy.Fields, func(r ast.Node) { + cpy.Fields = r.(*ast.FieldList) + }) + + replace(&cpy) + case *ast.FuncType: + cpy := *n + if cpy.Params != nil { + walkToReplace(v, cpy.Params, func(r ast.Node) { + cpy.Params = r.(*ast.FieldList) + }) + } + if cpy.Results != nil { + walkToReplace(v, cpy.Results, func(r ast.Node) { + cpy.Results = r.(*ast.FieldList) + }) + } + + replace(&cpy) + case *ast.InterfaceType: + cpy := *n + walkToReplace(v, cpy.Methods, func(r ast.Node) { + cpy.Methods = r.(*ast.FieldList) + }) + + replace(&cpy) + case *ast.MapType: + cpy := *n + walkToReplace(v, cpy.Key, func(r ast.Node) { + cpy.Key = r.(ast.Expr) + }) + walkToReplace(v, cpy.Value, func(r ast.Node) { + cpy.Value = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.ChanType: + cpy := *n + walkToReplace(v, cpy.Value, func(r ast.Node) { + cpy.Value = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.DeclStmt: + cpy := *n + walkToReplace(v, cpy.Decl, func(r ast.Node) { + cpy.Decl = r.(ast.Decl) + }) + + replace(&cpy) + case *ast.LabeledStmt: + cpy := *n + walkToReplace(v, cpy.Label, func(r ast.Node) { + cpy.Label = r.(*ast.Ident) + }) + walkToReplace(v, cpy.Stmt, func(r ast.Node) { + cpy.Stmt = r.(ast.Stmt) + }) + + replace(&cpy) + case *ast.ExprStmt: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.SendStmt: + cpy := *n + walkToReplace(v, cpy.Chan, func(r ast.Node) { + cpy.Chan = r.(ast.Expr) + }) + walkToReplace(v, cpy.Value, func(r ast.Node) { + cpy.Value = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.IncDecStmt: + cpy := *n + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + + replace(&cpy) + case *ast.AssignStmt: + cpy := *n + if n.Lhs != nil { + cpy.Lhs = make([]ast.Expr, len(n.Lhs)) + copy(cpy.Lhs, n.Lhs) + } + if n.Rhs != nil { + cpy.Rhs = make([]ast.Expr, len(n.Rhs)) + copy(cpy.Rhs, n.Rhs) + } + + walkExprList(v, cpy.Lhs) + walkExprList(v, cpy.Rhs) + + replace(&cpy) + case *ast.GoStmt: + cpy := *n + walkToReplace(v, cpy.Call, func(r ast.Node) { + cpy.Call = r.(*ast.CallExpr) + }) + + replace(&cpy) + case *ast.DeferStmt: + cpy := *n + walkToReplace(v, cpy.Call, func(r ast.Node) { + cpy.Call = r.(*ast.CallExpr) + }) + + replace(&cpy) + case *ast.ReturnStmt: + cpy := *n + if n.Results != nil { + cpy.Results = make([]ast.Expr, len(n.Results)) + copy(cpy.Results, n.Results) + } + + walkExprList(v, cpy.Results) + + replace(&cpy) + case *ast.BranchStmt: + cpy := *n + if cpy.Label != nil { + walkToReplace(v, cpy.Label, func(r ast.Node) { + cpy.Label = r.(*ast.Ident) + }) + } + + replace(&cpy) + case *ast.BlockStmt: + cpy := *n + if n.List != nil { + cpy.List = make([]ast.Stmt, len(n.List)) + copy(cpy.List, n.List) + } + + walkStmtList(v, cpy.List) + + replace(&cpy) + case *ast.IfStmt: + cpy := *n + + if cpy.Init != nil { + walkToReplace(v, cpy.Init, func(r ast.Node) { + cpy.Init = r.(ast.Stmt) + }) + } + walkToReplace(v, cpy.Cond, func(r ast.Node) { + cpy.Cond = r.(ast.Expr) + }) + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + if cpy.Else != nil { + walkToReplace(v, cpy.Else, func(r ast.Node) { + cpy.Else = r.(ast.Stmt) + }) + } + + replace(&cpy) + case *ast.CaseClause: + cpy := *n + if n.List != nil { + cpy.List = make([]ast.Expr, len(n.List)) + copy(cpy.List, n.List) + } + if n.Body != nil { + cpy.Body = make([]ast.Stmt, len(n.Body)) + copy(cpy.Body, n.Body) + } + + walkExprList(v, cpy.List) + walkStmtList(v, cpy.Body) + + replace(&cpy) + case *ast.SwitchStmt: + cpy := *n + if cpy.Init != nil { + walkToReplace(v, cpy.Init, func(r ast.Node) { + cpy.Init = r.(ast.Stmt) + }) + } + if cpy.Tag != nil { + walkToReplace(v, cpy.Tag, func(r ast.Node) { + cpy.Tag = r.(ast.Expr) + }) + } + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + + replace(&cpy) + case *ast.TypeSwitchStmt: + cpy := *n + if cpy.Init != nil { + walkToReplace(v, cpy.Init, func(r ast.Node) { + cpy.Init = r.(ast.Stmt) + }) + } + walkToReplace(v, cpy.Assign, func(r ast.Node) { + cpy.Assign = r.(ast.Stmt) + }) + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + + replace(&cpy) + case *ast.CommClause: + cpy := *n + if n.Body != nil { + cpy.Body = make([]ast.Stmt, len(n.Body)) + copy(cpy.Body, n.Body) + } + + if cpy.Comm != nil { + walkToReplace(v, cpy.Comm, func(r ast.Node) { + cpy.Comm = r.(ast.Stmt) + }) + } + walkStmtList(v, cpy.Body) + + replace(&cpy) + case *ast.SelectStmt: + cpy := *n + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + + replace(&cpy) + case *ast.ForStmt: + cpy := *n + if cpy.Init != nil { + walkToReplace(v, cpy.Init, func(r ast.Node) { + cpy.Init = r.(ast.Stmt) + }) + } + if cpy.Cond != nil { + walkToReplace(v, cpy.Cond, func(r ast.Node) { + cpy.Cond = r.(ast.Expr) + }) + } + if cpy.Post != nil { + walkToReplace(v, cpy.Post, func(r ast.Node) { + cpy.Post = r.(ast.Stmt) + }) + } + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + + replace(&cpy) + case *ast.RangeStmt: + cpy := *n + if cpy.Key != nil { + walkToReplace(v, cpy.Key, func(r ast.Node) { + cpy.Key = r.(ast.Expr) + }) + } + if cpy.Value != nil { + walkToReplace(v, cpy.Value, func(r ast.Node) { + cpy.Value = r.(ast.Expr) + }) + } + walkToReplace(v, cpy.X, func(r ast.Node) { + cpy.X = r.(ast.Expr) + }) + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + + // Declarations + replace(&cpy) + case *ast.ImportSpec: + cpy := *n + if cpy.Doc != nil { + walkToReplace(v, cpy.Doc, func(r ast.Node) { + cpy.Doc = r.(*ast.CommentGroup) + }) + } + if cpy.Name != nil { + walkToReplace(v, cpy.Name, func(r ast.Node) { + cpy.Name = r.(*ast.Ident) + }) + } + walkToReplace(v, cpy.Path, func(r ast.Node) { + cpy.Path = r.(*ast.BasicLit) + }) + if cpy.Comment != nil { + walkToReplace(v, cpy.Comment, func(r ast.Node) { + cpy.Comment = r.(*ast.CommentGroup) + }) + } + + replace(&cpy) + case *ast.ValueSpec: + cpy := *n + if n.Names != nil { + cpy.Names = make([]*ast.Ident, len(n.Names)) + copy(cpy.Names, n.Names) + } + if n.Values != nil { + cpy.Values = make([]ast.Expr, len(n.Values)) + copy(cpy.Values, n.Values) + } + + if cpy.Doc != nil { + walkToReplace(v, cpy.Doc, func(r ast.Node) { + cpy.Doc = r.(*ast.CommentGroup) + }) + } + + walkIdentList(v, cpy.Names) + + if cpy.Type != nil { + walkToReplace(v, cpy.Type, func(r ast.Node) { + cpy.Type = r.(ast.Expr) + }) + } + + walkExprList(v, cpy.Values) + + if cpy.Comment != nil { + walkToReplace(v, cpy.Comment, func(r ast.Node) { + cpy.Comment = r.(*ast.CommentGroup) + }) + } + + replace(&cpy) + + case *ast.TypeSpec: + cpy := *n + + if cpy.Doc != nil { + walkToReplace(v, cpy.Doc, func(r ast.Node) { + cpy.Doc = r.(*ast.CommentGroup) + }) + } + walkToReplace(v, cpy.Name, func(r ast.Node) { + cpy.Name = r.(*ast.Ident) + }) + walkToReplace(v, cpy.Type, func(r ast.Node) { + cpy.Type = r.(ast.Expr) + }) + if cpy.Comment != nil { + walkToReplace(v, cpy.Comment, func(r ast.Node) { + cpy.Comment = r.(*ast.CommentGroup) + }) + } + + replace(&cpy) + case *ast.GenDecl: + cpy := *n + if n.Specs != nil { + cpy.Specs = make([]ast.Spec, len(n.Specs)) + copy(cpy.Specs, n.Specs) + } + + if cpy.Doc != nil { + walkToReplace(v, cpy.Doc, func(r ast.Node) { + cpy.Doc = r.(*ast.CommentGroup) + }) + } + for i, s := range cpy.Specs { + walkToReplace(v, s, func(r ast.Node) { + cpy.Specs[i] = r.(ast.Spec) + }) + } + + replace(&cpy) + case *ast.FuncDecl: + cpy := *n + + if cpy.Doc != nil { + walkToReplace(v, cpy.Doc, func(r ast.Node) { + cpy.Doc = r.(*ast.CommentGroup) + }) + } + if cpy.Recv != nil { + walkToReplace(v, cpy.Recv, func(r ast.Node) { + cpy.Recv = r.(*ast.FieldList) + }) + } + walkToReplace(v, cpy.Name, func(r ast.Node) { + cpy.Name = r.(*ast.Ident) + }) + walkToReplace(v, cpy.Type, func(r ast.Node) { + cpy.Type = r.(*ast.FuncType) + }) + if cpy.Body != nil { + walkToReplace(v, cpy.Body, func(r ast.Node) { + cpy.Body = r.(*ast.BlockStmt) + }) + } + + // Files and packages + replace(&cpy) + case *ast.File: + cpy := *n + + if cpy.Doc != nil { + walkToReplace(v, cpy.Doc, func(r ast.Node) { + cpy.Doc = r.(*ast.CommentGroup) + }) + } + walkToReplace(v, cpy.Name, func(r ast.Node) { + cpy.Name = r.(*ast.Ident) + }) + walkDeclList(v, cpy.Decls) + // don't walk cpy.Comments - they have been + // visited already through the individual + // nodes + + replace(&cpy) + case *ast.Package: + cpy := *n + cpy.Files = map[string]*ast.File{} + + for i, f := range n.Files { + cpy.Files[i] = f + walkToReplace(v, f, func(r ast.Node) { + cpy.Files[i] = r.(*ast.File) + }) + } + replace(&cpy) + + default: + panic(fmt.Sprintf("walkToReplace: unexpected node type %T", n)) + } + + if v != nil { + v.Visit(nil, func(ast.Node) { panic("can't replace the go-up nil") }) + } +} diff --git a/cmd/kitgen/sourcecontext.go b/cmd/kitgen/sourcecontext.go new file mode 100644 index 000000000..35933a20f --- /dev/null +++ b/cmd/kitgen/sourcecontext.go @@ -0,0 +1,53 @@ +package main + +import ( + "fmt" + "go/ast" + "go/token" +) + +type sourceContext struct { + pkg *ast.Ident + imports []*ast.ImportSpec + interfaces []iface + types []*ast.TypeSpec +} + +func (sc *sourceContext) validate() error { + if len(sc.interfaces) != 1 { + return fmt.Errorf("found %d interfaces, expecting exactly 1", len(sc.interfaces)) + } + for _, i := range sc.interfaces { + for _, m := range i.methods { + if len(m.results) < 1 { + return fmt.Errorf("method %q of interface %q has no result types", m.name, i.name) + } + } + } + return nil +} + +func (sc *sourceContext) importDecls() (decls []ast.Decl) { + have := map[string]struct{}{} + notHave := func(is *ast.ImportSpec) bool { + if _, has := have[is.Path.Value]; has { + return false + } + have[is.Path.Value] = struct{}{} + return true + } + + for _, is := range sc.imports { + if notHave(is) { + decls = append(decls, importFor(is)) + } + } + + for _, is := range fetchImports() { + if notHave(is) { + decls = append(decls, &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{is}}) + } + } + + return +} diff --git a/cmd/kitgen/templates/full.go b/cmd/kitgen/templates/full.go new file mode 100644 index 000000000..c3516856a --- /dev/null +++ b/cmd/kitgen/templates/full.go @@ -0,0 +1,60 @@ +package foo + +import ( + "context" + "encoding/json" + "errors" + "net/http" + + "github.com/go-kit/kit/endpoint" + httptransport "github.com/go-kit/kit/transport/http" +) + +type ExampleService struct { +} + +type ExampleRequest struct { + I int + S string +} +type ExampleResponse struct { + S string + Err error +} + +type Endpoints struct { + ExampleEndpoint endpoint.Endpoint +} + +func (f ExampleService) ExampleEndpoint(ctx context.Context, i int, s string) (string, error) { + panic(errors.New("not implemented")) +} + +func makeExampleEndpoint(f ExampleService) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(ExampleRequest) + s, err := f.ExampleEndpoint(ctx, req.I, req.S) + return ExampleResponse{S: s, Err: err}, nil + } +} + +func inlineHandlerBuilder(m *http.ServeMux, endpoints Endpoints) { + m.Handle("/bar", httptransport.NewServer(endpoints.ExampleEndpoint, DecodeExampleRequest, EncodeExampleResponse)) +} + +func NewHTTPHandler(endpoints Endpoints) http.Handler { + m := http.NewServeMux() + inlineHandlerBuilder(m, endpoints) + return m +} + +func DecodeExampleRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req ExampleRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} + +func EncodeExampleResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/anonfields/default/endpoints/endpoints.go b/cmd/kitgen/testdata/anonfields/default/endpoints/endpoints.go new file mode 100644 index 000000000..b6902de65 --- /dev/null +++ b/cmd/kitgen/testdata/anonfields/default/endpoints/endpoints.go @@ -0,0 +1,28 @@ +package endpoints + +import "context" + +import "github.com/go-kit/kit/endpoint" + +import "github.com/go-kit/kit/cmd/kitgen/testdata/anonfields/default/service" + +type FooRequest struct { + I int + S string +} +type FooResponse struct { + I int + Err error +} + +func makeFooEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(FooRequest) + i, err := s.Foo(ctx, req.I, req.S) + return FooResponse{I: i, Err: err}, nil + } +} + +type Endpoints struct { + Foo endpoint.Endpoint +} diff --git a/cmd/kitgen/testdata/anonfields/default/http/http.go b/cmd/kitgen/testdata/anonfields/default/http/http.go new file mode 100644 index 000000000..e02944084 --- /dev/null +++ b/cmd/kitgen/testdata/anonfields/default/http/http.go @@ -0,0 +1,24 @@ +package http + +import "context" +import "encoding/json" + +import "net/http" + +import httptransport "github.com/go-kit/kit/transport/http" +import "github.com/go-kit/kit/cmd/kitgen/testdata/anonfields/default/endpoints" + +func NewHTTPHandler(endpoints endpoints.Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/foo", httptransport.NewServer(endpoints.Foo, DecodeFooRequest, EncodeFooResponse)) + return m +} +func DecodeFooRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.FooRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeFooResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/anonfields/default/service/service.go b/cmd/kitgen/testdata/anonfields/default/service/service.go new file mode 100644 index 000000000..8adbd5a14 --- /dev/null +++ b/cmd/kitgen/testdata/anonfields/default/service/service.go @@ -0,0 +1,12 @@ +package service + +import "context" + +import "errors" + +type Service struct { +} + +func (s Service) Foo(ctx context.Context, i int, string1 string) (int, error) { + panic(errors.New("not implemented")) +} diff --git a/cmd/kitgen/testdata/anonfields/flat/gokit.go b/cmd/kitgen/testdata/anonfields/flat/gokit.go new file mode 100644 index 000000000..f19d2b275 --- /dev/null +++ b/cmd/kitgen/testdata/anonfields/flat/gokit.go @@ -0,0 +1,51 @@ +package foo + +import "context" +import "encoding/json" +import "errors" +import "net/http" +import "github.com/go-kit/kit/endpoint" +import httptransport "github.com/go-kit/kit/transport/http" + +type Service struct { +} + +func (s Service) Foo(ctx context.Context, i int, string1 string) (int, error) { + panic(errors.New("not implemented")) +} + +type FooRequest struct { + I int + S string +} +type FooResponse struct { + I int + Err error +} + +func makeFooEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(FooRequest) + i, err := s.Foo(ctx, req.I, req.S) + return FooResponse{I: i, Err: err}, nil + } +} + +type Endpoints struct { + Foo endpoint.Endpoint +} + +func NewHTTPHandler(endpoints Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/foo", httptransport.NewServer(endpoints.Foo, DecodeFooRequest, EncodeFooResponse)) + return m +} +func DecodeFooRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req FooRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeFooResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/anonfields/in.go b/cmd/kitgen/testdata/anonfields/in.go new file mode 100644 index 000000000..c0c87d808 --- /dev/null +++ b/cmd/kitgen/testdata/anonfields/in.go @@ -0,0 +1,6 @@ +package foo + +// from https://github.com/go-kit/kit/pull/589#issuecomment-319937530 +type Service interface { + Foo(context.Context, int, string) (int, error) +} diff --git a/cmd/kitgen/testdata/foo/default/endpoints/endpoints.go b/cmd/kitgen/testdata/foo/default/endpoints/endpoints.go new file mode 100644 index 000000000..ff8ef0184 --- /dev/null +++ b/cmd/kitgen/testdata/foo/default/endpoints/endpoints.go @@ -0,0 +1,28 @@ +package endpoints + +import "context" + +import "github.com/go-kit/kit/endpoint" + +import "github.com/go-kit/kit/cmd/kitgen/testdata/foo/default/service" + +type BarRequest struct { + I int + S string +} +type BarResponse struct { + S string + Err error +} + +func makeBarEndpoint(f service.FooService) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(BarRequest) + s, err := f.Bar(ctx, req.I, req.S) + return BarResponse{S: s, Err: err}, nil + } +} + +type Endpoints struct { + Bar endpoint.Endpoint +} diff --git a/cmd/kitgen/testdata/foo/default/http/http.go b/cmd/kitgen/testdata/foo/default/http/http.go new file mode 100644 index 000000000..286926306 --- /dev/null +++ b/cmd/kitgen/testdata/foo/default/http/http.go @@ -0,0 +1,24 @@ +package http + +import "context" +import "encoding/json" + +import "net/http" + +import httptransport "github.com/go-kit/kit/transport/http" +import "github.com/go-kit/kit/cmd/kitgen/testdata/foo/default/endpoints" + +func NewHTTPHandler(endpoints endpoints.Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/bar", httptransport.NewServer(endpoints.Bar, DecodeBarRequest, EncodeBarResponse)) + return m +} +func DecodeBarRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.BarRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeBarResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/foo/default/service/service.go b/cmd/kitgen/testdata/foo/default/service/service.go new file mode 100644 index 000000000..02a7babea --- /dev/null +++ b/cmd/kitgen/testdata/foo/default/service/service.go @@ -0,0 +1,12 @@ +package service + +import "context" + +import "errors" + +type FooService struct { +} + +func (f FooService) Bar(ctx context.Context, i int, s string) (string, error) { + panic(errors.New("not implemented")) +} diff --git a/cmd/kitgen/testdata/foo/flat/gokit.go b/cmd/kitgen/testdata/foo/flat/gokit.go new file mode 100644 index 000000000..9e0bc1f9b --- /dev/null +++ b/cmd/kitgen/testdata/foo/flat/gokit.go @@ -0,0 +1,51 @@ +package foo + +import "context" +import "encoding/json" +import "errors" +import "net/http" +import "github.com/go-kit/kit/endpoint" +import httptransport "github.com/go-kit/kit/transport/http" + +type FooService struct { +} + +func (f FooService) Bar(ctx context.Context, i int, s string) (string, error) { + panic(errors.New("not implemented")) +} + +type BarRequest struct { + I int + S string +} +type BarResponse struct { + S string + Err error +} + +func makeBarEndpoint(f FooService) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(BarRequest) + s, err := f.Bar(ctx, req.I, req.S) + return BarResponse{S: s, Err: err}, nil + } +} + +type Endpoints struct { + Bar endpoint.Endpoint +} + +func NewHTTPHandler(endpoints Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/bar", httptransport.NewServer(endpoints.Bar, DecodeBarRequest, EncodeBarResponse)) + return m +} +func DecodeBarRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req BarRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeBarResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/foo/in.go b/cmd/kitgen/testdata/foo/in.go new file mode 100644 index 000000000..1c01933fa --- /dev/null +++ b/cmd/kitgen/testdata/foo/in.go @@ -0,0 +1,5 @@ +package foo + +type FooService interface { + Bar(ctx context.Context, i int, s string) (string, error) +} diff --git a/cmd/kitgen/testdata/profilesvc/default/endpoints/endpoints.go b/cmd/kitgen/testdata/profilesvc/default/endpoints/endpoints.go new file mode 100644 index 000000000..892a1e1ff --- /dev/null +++ b/cmd/kitgen/testdata/profilesvc/default/endpoints/endpoints.go @@ -0,0 +1,162 @@ +package endpoints + +import "context" + +import "github.com/go-kit/kit/endpoint" + +import "github.com/go-kit/kit/cmd/kitgen/testdata/profilesvc/default/service" + +type PostProfileRequest struct { + P service.Profile +} +type PostProfileResponse struct { + Err error +} + +func makePostProfileEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PostProfileRequest) + err := s.PostProfile(ctx, req.P) + return PostProfileResponse{Err: err}, nil + } +} + +type GetProfileRequest struct { + Id string +} +type GetProfileResponse struct { + P service.Profile + Err error +} + +func makeGetProfileEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(GetProfileRequest) + P, err := s.GetProfile(ctx, req.Id) + return GetProfileResponse{P: P, Err: err}, nil + } +} + +type PutProfileRequest struct { + Id string + P service.Profile +} +type PutProfileResponse struct { + Err error +} + +func makePutProfileEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PutProfileRequest) + err := s.PutProfile(ctx, req.Id, req.P) + return PutProfileResponse{Err: err}, nil + } +} + +type PatchProfileRequest struct { + Id string + P service.Profile +} +type PatchProfileResponse struct { + Err error +} + +func makePatchProfileEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PatchProfileRequest) + err := s.PatchProfile(ctx, req.Id, req.P) + return PatchProfileResponse{Err: err}, nil + } +} + +type DeleteProfileRequest struct { + Id string +} +type DeleteProfileResponse struct { + Err error +} + +func makeDeleteProfileEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(DeleteProfileRequest) + err := s.DeleteProfile(ctx, req.Id) + return DeleteProfileResponse{Err: err}, nil + } +} + +type GetAddressesRequest struct { + ProfileID string +} +type GetAddressesResponse struct { + S []service.Address + Err error +} + +func makeGetAddressesEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(GetAddressesRequest) + slice1, err := s.GetAddresses(ctx, req.ProfileID) + return GetAddressesResponse{S: slice1, Err: err}, nil + } +} + +type GetAddressRequest struct { + ProfileID string + AddressID string +} +type GetAddressResponse struct { + A service.Address + Err error +} + +func makeGetAddressEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(GetAddressRequest) + A, err := s.GetAddress(ctx, req.ProfileID, req.AddressID) + return GetAddressResponse{A: A, Err: err}, nil + } +} + +type PostAddressRequest struct { + ProfileID string + A service.Address +} +type PostAddressResponse struct { + Err error +} + +func makePostAddressEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PostAddressRequest) + err := s.PostAddress(ctx, req.ProfileID, req.A) + return PostAddressResponse{Err: err}, nil + } +} + +type DeleteAddressRequest struct { + ProfileID string + AddressID string +} +type DeleteAddressResponse struct { + Err error +} + +func makeDeleteAddressEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(DeleteAddressRequest) + err := s.DeleteAddress(ctx, req.ProfileID, req.AddressID) + return DeleteAddressResponse{Err: err}, nil + } +} + +type Endpoints struct { + PostProfile endpoint.Endpoint + GetProfile endpoint.Endpoint + PutProfile endpoint.Endpoint + PatchProfile endpoint.Endpoint + DeleteProfile endpoint.Endpoint + GetAddresses endpoint.Endpoint + GetAddress endpoint.Endpoint + PostAddress endpoint.Endpoint + DeleteAddress endpoint.Endpoint +} diff --git a/cmd/kitgen/testdata/profilesvc/default/http/http.go b/cmd/kitgen/testdata/profilesvc/default/http/http.go new file mode 100644 index 000000000..b59f762b8 --- /dev/null +++ b/cmd/kitgen/testdata/profilesvc/default/http/http.go @@ -0,0 +1,104 @@ +package http + +import "context" +import "encoding/json" + +import "net/http" + +import httptransport "github.com/go-kit/kit/transport/http" +import "github.com/go-kit/kit/cmd/kitgen/testdata/profilesvc/default/endpoints" + +func NewHTTPHandler(endpoints endpoints.Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/postprofile", httptransport.NewServer(endpoints.PostProfile, DecodePostProfileRequest, EncodePostProfileResponse)) + m.Handle("/getprofile", httptransport.NewServer(endpoints.GetProfile, DecodeGetProfileRequest, EncodeGetProfileResponse)) + m.Handle("/putprofile", httptransport.NewServer(endpoints.PutProfile, DecodePutProfileRequest, EncodePutProfileResponse)) + m.Handle("/patchprofile", httptransport.NewServer(endpoints.PatchProfile, DecodePatchProfileRequest, EncodePatchProfileResponse)) + m.Handle("/deleteprofile", httptransport.NewServer(endpoints.DeleteProfile, DecodeDeleteProfileRequest, EncodeDeleteProfileResponse)) + m.Handle("/getaddresses", httptransport.NewServer(endpoints.GetAddresses, DecodeGetAddressesRequest, EncodeGetAddressesResponse)) + m.Handle("/getaddress", httptransport.NewServer(endpoints.GetAddress, DecodeGetAddressRequest, EncodeGetAddressResponse)) + m.Handle("/postaddress", httptransport.NewServer(endpoints.PostAddress, DecodePostAddressRequest, EncodePostAddressResponse)) + m.Handle("/deleteaddress", httptransport.NewServer(endpoints.DeleteAddress, DecodeDeleteAddressRequest, EncodeDeleteAddressResponse)) + return m +} +func DecodePostProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.PostProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePostProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeGetProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.GetProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeGetProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodePutProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.PutProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePutProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodePatchProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.PatchProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePatchProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeDeleteProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.DeleteProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeDeleteProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeGetAddressesRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.GetAddressesRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeGetAddressesResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeGetAddressRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.GetAddressRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeGetAddressResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodePostAddressRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.PostAddressRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePostAddressResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeDeleteAddressRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.DeleteAddressRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeDeleteAddressResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/profilesvc/default/service/service.go b/cmd/kitgen/testdata/profilesvc/default/service/service.go new file mode 100644 index 000000000..42d1b4d2e --- /dev/null +++ b/cmd/kitgen/testdata/profilesvc/default/service/service.go @@ -0,0 +1,45 @@ +package service + +import "context" + +import "errors" + +type Profile struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Addresses []Address `json:"addresses,omitempty"` +} +type Address struct { + ID string `json:"id"` + Location string `json:"location,omitempty"` +} +type Service struct { +} + +func (s Service) PostProfile(ctx context.Context, p Profile) error { + panic(errors.New("not implemented")) +} +func (s Service) GetProfile(ctx context.Context, id string) (Profile, error) { + panic(errors.New("not implemented")) +} +func (s Service) PutProfile(ctx context.Context, id string, p Profile) error { + panic(errors.New("not implemented")) +} +func (s Service) PatchProfile(ctx context.Context, id string, p Profile) error { + panic(errors.New("not implemented")) +} +func (s Service) DeleteProfile(ctx context.Context, id string) error { + panic(errors.New("not implemented")) +} +func (s Service) GetAddresses(ctx context.Context, profileID string) ([]Address, error) { + panic(errors.New("not implemented")) +} +func (s Service) GetAddress(ctx context.Context, profileID string, addressID string) (Address, error) { + panic(errors.New("not implemented")) +} +func (s Service) PostAddress(ctx context.Context, profileID string, a Address) error { + panic(errors.New("not implemented")) +} +func (s Service) DeleteAddress(ctx context.Context, profileID string, addressID string) error { + panic(errors.New("not implemented")) +} diff --git a/cmd/kitgen/testdata/profilesvc/flat/gokit.go b/cmd/kitgen/testdata/profilesvc/flat/gokit.go new file mode 100644 index 000000000..10fb436d1 --- /dev/null +++ b/cmd/kitgen/testdata/profilesvc/flat/gokit.go @@ -0,0 +1,298 @@ +package profilesvc + +import "context" +import "encoding/json" +import "errors" +import "net/http" +import "github.com/go-kit/kit/endpoint" +import httptransport "github.com/go-kit/kit/transport/http" + +type Profile struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Addresses []Address `json:"addresses,omitempty"` +} +type Address struct { + ID string `json:"id"` + Location string `json:"location,omitempty"` +} +type Service struct { +} + +func (s Service) PostProfile(ctx context.Context, p Profile) error { + panic(errors.New("not implemented")) +} + +type PostProfileRequest struct { + P Profile +} +type PostProfileResponse struct { + Err error +} + +func makePostProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PostProfileRequest) + err := s.PostProfile(ctx, req.P) + return PostProfileResponse{Err: err}, nil + } +} +func (s Service) GetProfile(ctx context.Context, id string) (Profile, error) { + panic(errors.New("not implemented")) +} + +type GetProfileRequest struct { + Id string +} +type GetProfileResponse struct { + P Profile + Err error +} + +func makeGetProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(GetProfileRequest) + P, err := s.GetProfile(ctx, req.Id) + return GetProfileResponse{P: P, Err: err}, nil + } +} +func (s Service) PutProfile(ctx context.Context, id string, p Profile) error { + panic(errors.New("not implemented")) +} + +type PutProfileRequest struct { + Id string + P Profile +} +type PutProfileResponse struct { + Err error +} + +func makePutProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PutProfileRequest) + err := s.PutProfile(ctx, req.Id, req.P) + return PutProfileResponse{Err: err}, nil + } +} +func (s Service) PatchProfile(ctx context.Context, id string, p Profile) error { + panic(errors.New("not implemented")) +} + +type PatchProfileRequest struct { + Id string + P Profile +} +type PatchProfileResponse struct { + Err error +} + +func makePatchProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PatchProfileRequest) + err := s.PatchProfile(ctx, req.Id, req.P) + return PatchProfileResponse{Err: err}, nil + } +} +func (s Service) DeleteProfile(ctx context.Context, id string) error { + panic(errors.New("not implemented")) +} + +type DeleteProfileRequest struct { + Id string +} +type DeleteProfileResponse struct { + Err error +} + +func makeDeleteProfileEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(DeleteProfileRequest) + err := s.DeleteProfile(ctx, req.Id) + return DeleteProfileResponse{Err: err}, nil + } +} +func (s Service) GetAddresses(ctx context.Context, profileID string) ([]Address, error) { + panic(errors.New("not implemented")) +} + +type GetAddressesRequest struct { + ProfileID string +} +type GetAddressesResponse struct { + S []Address + Err error +} + +func makeGetAddressesEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(GetAddressesRequest) + slice1, err := s.GetAddresses(ctx, req.ProfileID) + return GetAddressesResponse{S: slice1, Err: err}, nil + } +} +func (s Service) GetAddress(ctx context.Context, profileID string, addressID string) (Address, error) { + panic(errors.New("not implemented")) +} + +type GetAddressRequest struct { + ProfileID string + AddressID string +} +type GetAddressResponse struct { + A Address + Err error +} + +func makeGetAddressEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(GetAddressRequest) + A, err := s.GetAddress(ctx, req.ProfileID, req.AddressID) + return GetAddressResponse{A: A, Err: err}, nil + } +} +func (s Service) PostAddress(ctx context.Context, profileID string, a Address) error { + panic(errors.New("not implemented")) +} + +type PostAddressRequest struct { + ProfileID string + A Address +} +type PostAddressResponse struct { + Err error +} + +func makePostAddressEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(PostAddressRequest) + err := s.PostAddress(ctx, req.ProfileID, req.A) + return PostAddressResponse{Err: err}, nil + } +} +func (s Service) DeleteAddress(ctx context.Context, profileID string, addressID string) error { + panic(errors.New("not implemented")) +} + +type DeleteAddressRequest struct { + ProfileID string + AddressID string +} +type DeleteAddressResponse struct { + Err error +} + +func makeDeleteAddressEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(DeleteAddressRequest) + err := s.DeleteAddress(ctx, req.ProfileID, req.AddressID) + return DeleteAddressResponse{Err: err}, nil + } +} + +type Endpoints struct { + PostProfile endpoint.Endpoint + GetProfile endpoint.Endpoint + PutProfile endpoint.Endpoint + PatchProfile endpoint.Endpoint + DeleteProfile endpoint.Endpoint + GetAddresses endpoint.Endpoint + GetAddress endpoint.Endpoint + PostAddress endpoint.Endpoint + DeleteAddress endpoint.Endpoint +} + +func NewHTTPHandler(endpoints Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/postprofile", httptransport.NewServer(endpoints.PostProfile, DecodePostProfileRequest, EncodePostProfileResponse)) + m.Handle("/getprofile", httptransport.NewServer(endpoints.GetProfile, DecodeGetProfileRequest, EncodeGetProfileResponse)) + m.Handle("/putprofile", httptransport.NewServer(endpoints.PutProfile, DecodePutProfileRequest, EncodePutProfileResponse)) + m.Handle("/patchprofile", httptransport.NewServer(endpoints.PatchProfile, DecodePatchProfileRequest, EncodePatchProfileResponse)) + m.Handle("/deleteprofile", httptransport.NewServer(endpoints.DeleteProfile, DecodeDeleteProfileRequest, EncodeDeleteProfileResponse)) + m.Handle("/getaddresses", httptransport.NewServer(endpoints.GetAddresses, DecodeGetAddressesRequest, EncodeGetAddressesResponse)) + m.Handle("/getaddress", httptransport.NewServer(endpoints.GetAddress, DecodeGetAddressRequest, EncodeGetAddressResponse)) + m.Handle("/postaddress", httptransport.NewServer(endpoints.PostAddress, DecodePostAddressRequest, EncodePostAddressResponse)) + m.Handle("/deleteaddress", httptransport.NewServer(endpoints.DeleteAddress, DecodeDeleteAddressRequest, EncodeDeleteAddressResponse)) + return m +} +func DecodePostProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req PostProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePostProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeGetProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req GetProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeGetProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodePutProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req PutProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePutProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodePatchProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req PatchProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePatchProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeDeleteProfileRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req DeleteProfileRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeDeleteProfileResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeGetAddressesRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req GetAddressesRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeGetAddressesResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeGetAddressRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req GetAddressRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeGetAddressResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodePostAddressRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req PostAddressRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodePostAddressResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeDeleteAddressRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req DeleteAddressRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeDeleteAddressResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/profilesvc/in.go b/cmd/kitgen/testdata/profilesvc/in.go new file mode 100644 index 000000000..208fed9b7 --- /dev/null +++ b/cmd/kitgen/testdata/profilesvc/in.go @@ -0,0 +1,24 @@ +package profilesvc + +type Service interface { + PostProfile(ctx context.Context, p Profile) error + GetProfile(ctx context.Context, id string) (Profile, error) + PutProfile(ctx context.Context, id string, p Profile) error + PatchProfile(ctx context.Context, id string, p Profile) error + DeleteProfile(ctx context.Context, id string) error + GetAddresses(ctx context.Context, profileID string) ([]Address, error) + GetAddress(ctx context.Context, profileID string, addressID string) (Address, error) + PostAddress(ctx context.Context, profileID string, a Address) error + DeleteAddress(ctx context.Context, profileID string, addressID string) error +} + +type Profile struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Addresses []Address `json:"addresses,omitempty"` +} + +type Address struct { + ID string `json:"id"` + Location string `json:"location,omitempty"` +} diff --git a/cmd/kitgen/testdata/stringservice/default/endpoints/endpoints.go b/cmd/kitgen/testdata/stringservice/default/endpoints/endpoints.go new file mode 100644 index 000000000..b3386aaeb --- /dev/null +++ b/cmd/kitgen/testdata/stringservice/default/endpoints/endpoints.go @@ -0,0 +1,44 @@ +package endpoints + +import "context" + +import "github.com/go-kit/kit/endpoint" + +import "github.com/go-kit/kit/cmd/kitgen/testdata/stringservice/default/service" + +type ConcatRequest struct { + A string + B string +} +type ConcatResponse struct { + S string + Err error +} + +func makeConcatEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(ConcatRequest) + string1, err := s.Concat(ctx, req.A, req.B) + return ConcatResponse{S: string1, Err: err}, nil + } +} + +type CountRequest struct { + S string +} +type CountResponse struct { + Count int +} + +func makeCountEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(CountRequest) + count := s.Count(ctx, req.S) + return CountResponse{Count: count}, nil + } +} + +type Endpoints struct { + Concat endpoint.Endpoint + Count endpoint.Endpoint +} diff --git a/cmd/kitgen/testdata/stringservice/default/http/http.go b/cmd/kitgen/testdata/stringservice/default/http/http.go new file mode 100644 index 000000000..31e2c1938 --- /dev/null +++ b/cmd/kitgen/testdata/stringservice/default/http/http.go @@ -0,0 +1,34 @@ +package http + +import "context" +import "encoding/json" + +import "net/http" + +import httptransport "github.com/go-kit/kit/transport/http" +import "github.com/go-kit/kit/cmd/kitgen/testdata/stringservice/default/endpoints" + +func NewHTTPHandler(endpoints endpoints.Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/concat", httptransport.NewServer(endpoints.Concat, DecodeConcatRequest, EncodeConcatResponse)) + m.Handle("/count", httptransport.NewServer(endpoints.Count, DecodeCountRequest, EncodeCountResponse)) + return m +} +func DecodeConcatRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.ConcatRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeConcatResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeCountRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.CountRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeCountResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/stringservice/default/service/service.go b/cmd/kitgen/testdata/stringservice/default/service/service.go new file mode 100644 index 000000000..ddf24f972 --- /dev/null +++ b/cmd/kitgen/testdata/stringservice/default/service/service.go @@ -0,0 +1,15 @@ +package service + +import "context" + +import "errors" + +type Service struct { +} + +func (s Service) Concat(ctx context.Context, a string, b string) (string, error) { + panic(errors.New("not implemented")) +} +func (s Service) Count(ctx context.Context, string1 string) int { + panic(errors.New("not implemented")) +} diff --git a/cmd/kitgen/testdata/stringservice/flat/gokit.go b/cmd/kitgen/testdata/stringservice/flat/gokit.go new file mode 100644 index 000000000..788b8b956 --- /dev/null +++ b/cmd/kitgen/testdata/stringservice/flat/gokit.go @@ -0,0 +1,80 @@ +package foo + +import "context" +import "encoding/json" +import "errors" +import "net/http" +import "github.com/go-kit/kit/endpoint" +import httptransport "github.com/go-kit/kit/transport/http" + +type Service struct { +} + +func (s Service) Concat(ctx context.Context, a string, b string) (string, error) { + panic(errors.New("not implemented")) +} + +type ConcatRequest struct { + A string + B string +} +type ConcatResponse struct { + S string + Err error +} + +func makeConcatEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(ConcatRequest) + string1, err := s.Concat(ctx, req.A, req.B) + return ConcatResponse{S: string1, Err: err}, nil + } +} +func (s Service) Count(ctx context.Context, string1 string) int { + panic(errors.New("not implemented")) +} + +type CountRequest struct { + S string +} +type CountResponse struct { + Count int +} + +func makeCountEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(CountRequest) + count := s.Count(ctx, req.S) + return CountResponse{Count: count}, nil + } +} + +type Endpoints struct { + Concat endpoint.Endpoint + Count endpoint.Endpoint +} + +func NewHTTPHandler(endpoints Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/concat", httptransport.NewServer(endpoints.Concat, DecodeConcatRequest, EncodeConcatResponse)) + m.Handle("/count", httptransport.NewServer(endpoints.Count, DecodeCountRequest, EncodeCountResponse)) + return m +} +func DecodeConcatRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req ConcatRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeConcatResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} +func DecodeCountRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req CountRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeCountResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/stringservice/in.go b/cmd/kitgen/testdata/stringservice/in.go new file mode 100644 index 000000000..68f018752 --- /dev/null +++ b/cmd/kitgen/testdata/stringservice/in.go @@ -0,0 +1,6 @@ +package foo + +type Service interface { + Concat(ctx context.Context, a, b string) (string, error) + Count(ctx context.Context, s string) (count int) +} diff --git a/cmd/kitgen/testdata/underscores/default/endpoints/endpoints.go b/cmd/kitgen/testdata/underscores/default/endpoints/endpoints.go new file mode 100644 index 000000000..b36f63b1c --- /dev/null +++ b/cmd/kitgen/testdata/underscores/default/endpoints/endpoints.go @@ -0,0 +1,27 @@ +package endpoints + +import "context" + +import "github.com/go-kit/kit/endpoint" + +import "github.com/go-kit/kit/cmd/kitgen/testdata/underscores/default/service" + +type FooRequest struct { + I int +} +type FooResponse struct { + I int + Err error +} + +func makeFooEndpoint(s service.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(FooRequest) + i, err := s.Foo(ctx, req.I) + return FooResponse{I: i, Err: err}, nil + } +} + +type Endpoints struct { + Foo endpoint.Endpoint +} diff --git a/cmd/kitgen/testdata/underscores/default/http/http.go b/cmd/kitgen/testdata/underscores/default/http/http.go new file mode 100644 index 000000000..a2844f048 --- /dev/null +++ b/cmd/kitgen/testdata/underscores/default/http/http.go @@ -0,0 +1,24 @@ +package http + +import "context" +import "encoding/json" + +import "net/http" + +import httptransport "github.com/go-kit/kit/transport/http" +import "github.com/go-kit/kit/cmd/kitgen/testdata/underscores/default/endpoints" + +func NewHTTPHandler(endpoints endpoints.Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/foo", httptransport.NewServer(endpoints.Foo, DecodeFooRequest, EncodeFooResponse)) + return m +} +func DecodeFooRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req endpoints.FooRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeFooResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/underscores/default/service/service.go b/cmd/kitgen/testdata/underscores/default/service/service.go new file mode 100644 index 000000000..e249a490f --- /dev/null +++ b/cmd/kitgen/testdata/underscores/default/service/service.go @@ -0,0 +1,12 @@ +package service + +import "context" + +import "errors" + +type Service struct { +} + +func (s Service) Foo(ctx context.Context, i int) (int, error) { + panic(errors.New("not implemented")) +} diff --git a/cmd/kitgen/testdata/underscores/flat/gokit.go b/cmd/kitgen/testdata/underscores/flat/gokit.go new file mode 100644 index 000000000..7f6a7da7f --- /dev/null +++ b/cmd/kitgen/testdata/underscores/flat/gokit.go @@ -0,0 +1,50 @@ +package underscores + +import "context" +import "encoding/json" +import "errors" +import "net/http" +import "github.com/go-kit/kit/endpoint" +import httptransport "github.com/go-kit/kit/transport/http" + +type Service struct { +} + +func (s Service) Foo(ctx context.Context, i int) (int, error) { + panic(errors.New("not implemented")) +} + +type FooRequest struct { + I int +} +type FooResponse struct { + I int + Err error +} + +func makeFooEndpoint(s Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(FooRequest) + i, err := s.Foo(ctx, req.I) + return FooResponse{I: i, Err: err}, nil + } +} + +type Endpoints struct { + Foo endpoint.Endpoint +} + +func NewHTTPHandler(endpoints Endpoints) http.Handler { + m := http.NewServeMux() + m.Handle("/foo", httptransport.NewServer(endpoints.Foo, DecodeFooRequest, EncodeFooResponse)) + return m +} +func DecodeFooRequest(_ context.Context, r *http.Request) (interface{}, error) { + var req FooRequest + err := json.NewDecoder(r.Body).Decode(&req) + return req, err +} +func EncodeFooResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + return json.NewEncoder(w).Encode(response) +} diff --git a/cmd/kitgen/testdata/underscores/in.go b/cmd/kitgen/testdata/underscores/in.go new file mode 100644 index 000000000..9457ee060 --- /dev/null +++ b/cmd/kitgen/testdata/underscores/in.go @@ -0,0 +1,7 @@ +package underscores + +import "context" + +type Service interface { + Foo(_ context.Context, _ int) (int, error) +} diff --git a/cmd/kitgen/transform.go b/cmd/kitgen/transform.go new file mode 100644 index 000000000..362398c92 --- /dev/null +++ b/cmd/kitgen/transform.go @@ -0,0 +1,230 @@ +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/token" + "io" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/davecgh/go-spew/spew" + "github.com/pkg/errors" + + "golang.org/x/tools/imports" +) + +type ( + files map[string]io.Reader + layout interface { + transformAST(ctx *sourceContext) (files, error) + } + outputTree map[string]*ast.File +) + +func (ot outputTree) addFile(path, pkgname string) *ast.File { + file := &ast.File{ + Name: id(pkgname), + Decls: []ast.Decl{}, + } + ot[path] = file + return file +} + +func getGopath() string { + gopath, set := os.LookupEnv("GOPATH") + if !set { + return filepath.Join(os.Getenv("HOME"), "go") + } + return gopath +} + +func importPath(targetDir, gopath string) (string, error) { + if !filepath.IsAbs(targetDir) { + return "", fmt.Errorf("%q is not an absolute path", targetDir) + } + + for _, dir := range filepath.SplitList(gopath) { + abspath, err := filepath.Abs(dir) + if err != nil { + continue + } + srcPath := filepath.Join(abspath, "src") + + res, err := filepath.Rel(srcPath, targetDir) + if err != nil { + continue + } + if strings.Index(res, "..") == -1 { + return res, nil + } + } + return "", fmt.Errorf("%q is not in GOPATH (%s)", targetDir, gopath) + +} + +func selectify(file *ast.File, pkgName, identName, importPath string) *ast.File { + if file.Name.Name == pkgName { + return file + } + + selector := sel(id(pkgName), id(identName)) + var did bool + if file, did = selectifyIdent(identName, file, selector); did { + addImport(file, importPath) + } + return file +} + +type selIdentFn func(ast.Node, func(ast.Node)) Visitor + +func (f selIdentFn) Visit(node ast.Node, r func(ast.Node)) Visitor { + return f(node, r) +} + +func selectifyIdent(identName string, file *ast.File, selector ast.Expr) (*ast.File, bool) { + var replaced bool + var r selIdentFn + r = selIdentFn(func(node ast.Node, replaceWith func(ast.Node)) Visitor { + switch id := node.(type) { + case *ast.SelectorExpr: + return nil + case *ast.Ident: + if id.Name == identName { + replaced = true + replaceWith(selector) + } + } + return r + }) + return WalkReplace(r, file).(*ast.File), replaced +} + +func formatNode(fname string, node ast.Node) (*bytes.Buffer, error) { + if file, is := node.(*ast.File); is { + sort.Stable(sortableDecls(file.Decls)) + } + outfset := token.NewFileSet() + buf := &bytes.Buffer{} + err := format.Node(buf, outfset, node) + if err != nil { + return nil, err + } + imps, err := imports.Process(fname, buf.Bytes(), nil) + if err != nil { + return nil, err + } + return bytes.NewBuffer(imps), nil +} + +type sortableDecls []ast.Decl + +func (sd sortableDecls) Len() int { + return len(sd) +} + +func (sd sortableDecls) Less(i int, j int) bool { + switch left := sd[i].(type) { + case *ast.GenDecl: + switch right := sd[j].(type) { + default: + return left.Tok == token.IMPORT + case *ast.GenDecl: + return left.Tok == token.IMPORT && right.Tok != token.IMPORT + } + } + return false +} + +func (sd sortableDecls) Swap(i int, j int) { + sd[i], sd[j] = sd[j], sd[i] +} + +func formatNodes(nodes outputTree) (files, error) { + res := files{} + var err error + for fn, node := range nodes { + res[fn], err = formatNode(fn, node) + if err != nil { + return nil, errors.Wrapf(err, "formatNodes") + } + } + return res, nil +} + +// XXX debug +func spewDecls(f *ast.File) { + for _, d := range f.Decls { + switch dcl := d.(type) { + default: + spew.Dump(dcl) + case *ast.GenDecl: + spew.Dump(dcl.Tok) + case *ast.FuncDecl: + spew.Dump(dcl.Name.Name) + } + } +} + +func addImports(root *ast.File, ctx *sourceContext) { + root.Decls = append(root.Decls, ctx.importDecls()...) +} + +func addImport(root *ast.File, path string) { + for _, d := range root.Decls { + if imp, is := d.(*ast.GenDecl); is && imp.Tok == token.IMPORT { + for _, s := range imp.Specs { + if s.(*ast.ImportSpec).Path.Value == `"`+path+`"` { + return // already have one + // xxx aliased imports? + } + } + } + } + root.Decls = append(root.Decls, importFor(importSpec(path))) +} + +func addStubStruct(root *ast.File, iface iface) { + root.Decls = append(root.Decls, iface.stubStructDecl()) +} + +func addType(root *ast.File, typ *ast.TypeSpec) { + root.Decls = append(root.Decls, typeDecl(typ)) +} + +func addMethod(root *ast.File, iface iface, meth method) { + def := meth.definition(iface) + root.Decls = append(root.Decls, def) +} + +func addRequestStruct(root *ast.File, meth method) { + root.Decls = append(root.Decls, meth.requestStruct()) +} + +func addResponseStruct(root *ast.File, meth method) { + root.Decls = append(root.Decls, meth.responseStruct()) +} + +func addEndpointMaker(root *ast.File, ifc iface, meth method) { + root.Decls = append(root.Decls, meth.endpointMaker(ifc)) +} + +func addEndpointsStruct(root *ast.File, ifc iface) { + root.Decls = append(root.Decls, ifc.endpointsStruct()) +} + +func addHTTPHandler(root *ast.File, ifc iface) { + root.Decls = append(root.Decls, ifc.httpHandler()) +} + +func addDecoder(root *ast.File, meth method) { + root.Decls = append(root.Decls, meth.decoderFunc()) +} + +func addEncoder(root *ast.File, meth method) { + root.Decls = append(root.Decls, meth.encoderFunc()) +}