diff --git a/internal/ctrlflow/ctrlflow.go b/internal/ctrlflow/ctrlflow.go index d40b08eb..4b61dbe3 100644 --- a/internal/ctrlflow/ctrlflow.go +++ b/internal/ctrlflow/ctrlflow.go @@ -26,10 +26,15 @@ const ( defaultBlockSplits = 0 defaultJunkJumps = 0 defaultFlattenPasses = 1 + defaultTrashBlocks = 0 maxBlockSplits = math.MaxInt32 maxJunkJumps = 256 maxFlattenPasses = 4 + maxTrashBlocks = 1024 + + minTrashBlockStmts = 1 + maxTrashBlockStmts = 32 ) type directiveParamMap map[string]string @@ -173,6 +178,8 @@ func Obfuscate(fset *token.FileSet, ssaPkg *ssa.Package, files []*ast.File, obfR return ast.NewIdent(name) } + var trashGen *trashGenerator + for idx, ssaFunc := range ssaFuncs { params := ssaParams[idx] @@ -184,7 +191,15 @@ func Obfuscate(fset *token.FileSet, ssaPkg *ssa.Package, files []*ast.File, obfR } flattenHardening := params.StringSlice("flatten_hardening") + trashBlockCount := params.GetInt("trash_blocks", defaultTrashBlocks, maxTrashBlocks) + if trashBlockCount > 0 && trashGen == nil { + trashGen = newTrashGenerator(ssaPkg.Prog, funcConfig.ImportNameResolver, obfRand) + } + applyObfuscation := func(ssaFunc *ssa.Function) []dispatcherInfo { + if trashBlockCount > 0 { + addTrashBlockMarkers(ssaFunc, trashBlockCount, obfRand) + } for i := 0; i < split; i++ { if !applySplitting(ssaFunc, obfRand) { break // no more candidates for splitting @@ -229,6 +244,14 @@ func Obfuscate(fset *token.FileSet, ssaPkg *ssa.Package, files []*ast.File, obfR funcConfig.SsaValueRemap = nil } + if trashBlockCount > 0 { + funcConfig.MarkerInstrCallback = func(m map[string]types.Type) []ast.Stmt { + return trashGen.Generate(minTrashBlockStmts+obfRand.Intn(maxTrashBlockStmts-minTrashBlockStmts), m) + } + } else { + funcConfig.MarkerInstrCallback = nil + } + astFunc, err := ssa2ast.Convert(ssaFunc, funcConfig) if err != nil { return "", nil, nil, err diff --git a/internal/ctrlflow/transform.go b/internal/ctrlflow/transform.go index c2ce1805..f6f0f4b0 100644 --- a/internal/ctrlflow/transform.go +++ b/internal/ctrlflow/transform.go @@ -1,12 +1,14 @@ package ctrlflow import ( + "go/constant" "go/token" "go/types" mathrand "math/rand" "strconv" "golang.org/x/tools/go/ssa" + "mvdan.cc/garble/internal/ssa2ast" ) type blockMapping struct { @@ -218,6 +220,84 @@ func applySplitting(ssaFunc *ssa.Function, obfRand *mathrand.Rand) bool { return true } +func randomAlwaysFalseCond(obfRand *mathrand.Rand) (*ssa.Const, token.Token, *ssa.Const) { + tokens := []token.Token{token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ} + + val1, val2 := constant.MakeInt64(int64(obfRand.Int31())), constant.MakeInt64(int64(obfRand.Int31())) + + var candidates []token.Token + for _, t := range tokens { + if !constant.Compare(val1, t, val2) { + candidates = append(candidates, t) + } + } + + return ssa.NewConst(val1, types.Typ[types.Int]), candidates[obfRand.Intn(len(candidates))], ssa.NewConst(val2, types.Typ[types.Int]) +} + +// addTrashBlockMarkers adds unreachable blocks with ssa2ast.MarkerInstr to further generate trash statements +func addTrashBlockMarkers(ssaFunc *ssa.Function, count int, obfRand *mathrand.Rand) { + var candidates []*ssa.BasicBlock + for _, block := range ssaFunc.Blocks { + if len(block.Succs) > 0 { + candidates = append(candidates, block) + } + } + + if len(candidates) == 0 { + return + } + + for i := 0; i < count; i++ { + targetBlock := candidates[obfRand.Intn(len(candidates))] + succsIdx := obfRand.Intn(len(targetBlock.Succs)) + succs := targetBlock.Succs[succsIdx] + + val1, op, val2 := randomAlwaysFalseCond(obfRand) + phiInstr := &ssa.Phi{ + Edges: []ssa.Value{val1}, + } + setType(phiInstr, types.Typ[types.Int]) + + binOpInstr := &ssa.BinOp{ + X: phiInstr, + Op: op, + Y: val2, + } + setType(binOpInstr, types.Typ[types.Bool]) + + jmpInstr := &ssa.If{Cond: binOpInstr} + *binOpInstr.Referrers() = append(*binOpInstr.Referrers(), jmpInstr) + + trashBlock := &ssa.BasicBlock{ + Comment: "ctrflow.trash." + strconv.Itoa(targetBlock.Index), + Instrs: []ssa.Instruction{ + ssa2ast.MarkerInstr, + &ssa.Jump{}, + }, + } + setBlockParent(trashBlock, ssaFunc) + + trashBlockDispatch := &ssa.BasicBlock{ + Comment: "ctrflow.trash.cond." + strconv.Itoa(targetBlock.Index), + Instrs: []ssa.Instruction{ + phiInstr, + binOpInstr, + jmpInstr, + }, + Preds: []*ssa.BasicBlock{targetBlock}, + Succs: []*ssa.BasicBlock{trashBlock, succs}, + } + setBlockParent(trashBlockDispatch, ssaFunc) + targetBlock.Succs[succsIdx] = trashBlockDispatch + + trashBlock.Preds = []*ssa.BasicBlock{trashBlockDispatch, trashBlock} + trashBlock.Succs = []*ssa.BasicBlock{trashBlock} + + ssaFunc.Blocks = append(ssaFunc.Blocks, trashBlockDispatch, trashBlock) + } +} + func fixBlockIndexes(ssaFunc *ssa.Function) { for i, block := range ssaFunc.Blocks { block.Index = i diff --git a/internal/ctrlflow/trash.go b/internal/ctrlflow/trash.go new file mode 100644 index 00000000..448d0d60 --- /dev/null +++ b/internal/ctrlflow/trash.go @@ -0,0 +1,513 @@ +package ctrlflow + +import ( + "encoding/base32" + "encoding/base64" + "encoding/hex" + "fmt" + "go/ast" + "go/token" + "go/types" + "math" + mathrand "math/rand" + "strconv" + "strings" + + "golang.org/x/exp/maps" + "golang.org/x/tools/go/ssa" + ah "mvdan.cc/garble/internal/asthelper" + "mvdan.cc/garble/internal/ssa2ast" +) + +const ( + varProb = 0.6 + globalProb = 0.4 + assignVarProb = 0.3 + methodCallProb = 0.5 + + minMethodsForType = 2 + maxStringLen = 32 + minVarsForAssign = 2 + maxAssignVars = 4 + maxVariadicParams = 5 + + limitFunctionCount = 256 +) + +var stringEncoders = []func([]byte) string{ + hex.EncodeToString, + base64.StdEncoding.EncodeToString, + base64.URLEncoding.EncodeToString, + base32.HexEncoding.EncodeToString, + base32.StdEncoding.EncodeToString, +} + +var valueGenerators = map[types.Type]func(rand *mathrand.Rand, targetType types.Type) ast.Expr{ + types.Typ[types.Bool]: func(rand *mathrand.Rand, _ types.Type) ast.Expr { + var val string + if rand.Float32() > 0.5 { + val = "true" + } else { + val = "false" + } + return ast.NewIdent(val) + }, + types.Typ[types.String]: func(rand *mathrand.Rand, _ types.Type) ast.Expr { + buf := make([]byte, 1+rand.Intn(maxStringLen)) + rand.Read(buf) + + return ah.StringLit(stringEncoders[rand.Intn(len(stringEncoders))](buf)) + }, + types.Typ[types.UntypedNil]: func(rand *mathrand.Rand, _ types.Type) ast.Expr { + return ast.NewIdent("nil") + }, + types.Typ[types.Float32]: func(rand *mathrand.Rand, t types.Type) ast.Expr { + var val float32 + if basic, ok := t.(*types.Basic); ok && (basic.Kind() != types.Float32 && basic.Kind() != types.Float64) { + // If the target type is not float, generate float without fractional part for safe type conversion + val = float32(rand.Intn(math.MaxInt8)) + } else { + val = rand.Float32() + } + return &ast.BasicLit{ + Kind: token.FLOAT, + Value: strconv.FormatFloat(float64(val), 'f', -1, 32), + } + }, + types.Typ[types.Float64]: func(rand *mathrand.Rand, t types.Type) ast.Expr { + var val float64 + if basic, ok := t.(*types.Basic); ok && basic.Kind() != types.Float64 { + // If the target type is not float64, generate float without fractional part for safe type conversion + val = float64(rand.Intn(math.MaxInt8)) + } else { + val = rand.Float64() + } + return &ast.BasicLit{ + Kind: token.FLOAT, + Value: strconv.FormatFloat(val, 'f', -1, 64), + } + }, + types.Typ[types.Int]: func(rand *mathrand.Rand, t types.Type) ast.Expr { + maxValue := math.MaxInt32 + if basic, ok := t.(*types.Basic); ok { + // Int can be cast to any numeric type, but compiler checks for overflow when casting constants. + // To prevent this, limiting the maximum value + switch basic.Kind() { + case types.Int8, types.Byte: + maxValue = math.MaxInt8 + case types.Int16, types.Uint16: + maxValue = math.MaxInt16 + } + } + return &ast.BasicLit{ + Kind: token.INT, + Value: strconv.FormatInt(int64(rand.Intn(maxValue)), 10), + } + }, +} + +func isInternal(path string) bool { + return strings.HasSuffix(path, "/internal") || strings.HasPrefix(path, "internal/") || strings.Contains(path, "/internal/") +} + +func under(t types.Type) types.Type { + for { + if t == t.Underlying() { + return t + } + t = t.Underlying() + } +} + +func deref(typ types.Type) types.Type { + if ptr, ok := typ.(*types.Pointer); ok { + typ = ptr.Elem() + } + return typ +} + +func canConvert(from, to types.Type) bool { + i, isInterface := under(to).(*types.Interface) + if isInterface { + if ptr, ok := from.(*types.Pointer); ok { + from = ptr.Elem() + } + return types.Implements(from, i) + } + return types.ConvertibleTo(from, to) +} + +func isSupportedType(v types.Type) bool { + for t := range valueGenerators { + if canConvert(t, v) { + return true + } + } + return false +} + +func isGenericType(p types.Type) bool { + switch typ := p.(type) { + case *types.Named: + return typ.TypeParams() != nil + case *types.Signature: + return typ.TypeParams() != nil && typ.RecvTypeParams() == nil + } + return false +} + +func isSupportedSig(m *types.Func) bool { + sig := m.Type().(*types.Signature) + if isGenericType(sig) { + return false + } + for i := 0; i < sig.Params().Len(); i++ { + if !isSupportedType(sig.Params().At(i).Type()) { + return false + } + } + return true +} + +type trashGenerator struct { + importNameResolver ssa2ast.ImportNameResolver + rand *mathrand.Rand + typeConverter *ssa2ast.TypeConverter + globals []*types.Var + pkgFunctions [][]*types.Func + methodCache map[types.Type][]*types.Func +} + +func newTrashGenerator(ssaProg *ssa.Program, importNameResolver ssa2ast.ImportNameResolver, rand *mathrand.Rand) *trashGenerator { + t := &trashGenerator{ + importNameResolver: importNameResolver, + rand: rand, + typeConverter: ssa2ast.NewTypeConverted(importNameResolver), + methodCache: make(map[types.Type][]*types.Func), + } + t.initialize(ssaProg) + return t +} + +type definedVar struct { + Type types.Type + External bool + + Refs int + Ident *ast.Ident + Assign *ast.AssignStmt +} + +func (d *definedVar) AddRef() { + if !d.External { + d.Refs++ + } +} + +func (d *definedVar) HasRefs() bool { + return d.External || d.Refs > 0 +} + +func (t *trashGenerator) initialize(ssaProg *ssa.Program) { + for _, p := range ssaProg.AllPackages() { + if isInternal(p.Pkg.Path()) || p.Pkg.Name() == "main" { + continue + } + var pkgFuncs []*types.Func + for _, member := range p.Members { + if !token.IsExported(member.Name()) { + continue + } + switch m := member.(type) { + case *ssa.Global: + if !isGenericType(m.Type()) && m.Object() != nil { + t.globals = append(t.globals, m.Object().(*types.Var)) + } + case *ssa.Function: + if m.Signature.Recv() != nil || !isSupportedSig(m.Object().(*types.Func)) { + continue + } + + pkgFuncs = append(pkgFuncs, m.Object().(*types.Func)) + if len(pkgFuncs) > limitFunctionCount { + break + } + } + } + + if len(pkgFuncs) > 0 { + t.pkgFunctions = append(t.pkgFunctions, pkgFuncs) + } + } +} + +func (t *trashGenerator) convertExpr(from, to types.Type, expr ast.Expr) ast.Expr { + if types.AssignableTo(from, to) { + return expr + } + + castExpr, err := t.typeConverter.Convert(to) + if err != nil { + panic(err) + } + return ah.CallExpr(&ast.ParenExpr{X: castExpr}, expr) +} + +func (t *trashGenerator) chooseRandomVar(typ types.Type, vars map[string]*definedVar) ast.Expr { + var candidates []string + for name, d := range vars { + if canConvert(d.Type, typ) { + candidates = append(candidates, name) + } + } + if len(candidates) == 0 { + return nil + } + + targetVarName := candidates[t.rand.Intn(len(candidates))] + targetVar := vars[targetVarName] + targetVar.AddRef() + + return t.convertExpr(targetVar.Type, typ, ast.NewIdent(targetVarName)) +} + +func (t *trashGenerator) chooseRandomGlobal(typ types.Type) ast.Expr { + var candidates []*types.Var + for _, global := range t.globals { + if canConvert(global.Type(), typ) { + candidates = append(candidates, global) + } + } + if len(candidates) == 0 { + return nil + } + + targetGlobal := candidates[t.rand.Intn(len(candidates))] + + var globalExpr ast.Expr + if pkgIdent := t.importNameResolver(targetGlobal.Pkg()); pkgIdent != nil { + globalExpr = ah.SelectExpr(pkgIdent, ast.NewIdent(targetGlobal.Name())) + } else { + globalExpr = ast.NewIdent(targetGlobal.Name()) + } + return t.convertExpr(targetGlobal.Type(), typ, globalExpr) +} + +func (t *trashGenerator) generateRandomConst(p types.Type, rand *mathrand.Rand) ast.Expr { + var candidates []types.Type + for typ := range valueGenerators { + if canConvert(typ, p) { + candidates = append(candidates, typ) + } + } + + if len(candidates) == 0 { + panic(fmt.Errorf("unsupported type: %v", p)) + } + + generatorType := candidates[rand.Intn(len(candidates))] + generator := valueGenerators[generatorType] + return t.convertExpr(generatorType, p, generator(rand, under(p))) +} + +func (t *trashGenerator) generateRandomValue(typ types.Type, vars map[string]*definedVar) ast.Expr { + if t.rand.Float32() < varProb { + if expr := t.chooseRandomVar(typ, vars); expr != nil { + return expr + } + } + if t.rand.Float32() < globalProb { + if expr := t.chooseRandomGlobal(typ); expr != nil { + return expr + } + } + return t.generateRandomConst(typ, t.rand) +} + +func (t *trashGenerator) cacheMethods(vars map[string]*definedVar) { + for _, d := range vars { + typ := deref(d.Type) + if _, ok := t.methodCache[typ]; ok { + continue + } + + var methods []*types.Func + switch typ := typ.(type) { + case *types.Named: + for i := 0; i < typ.NumMethods(); i++ { + if m := typ.Method(i); token.IsExported(m.Name()) && isSupportedSig(m) { + methods = append(methods, m) + if len(methods) > limitFunctionCount { + break + } + } + } + case *types.Interface: + for i := 0; i < typ.NumMethods(); i++ { + if m := typ.Method(i); token.IsExported(m.Name()) && isSupportedSig(m) { + methods = append(methods, m) + if len(methods) > limitFunctionCount { + break + } + } + } + } + if len(methods) < minMethodsForType { + methods = nil + } + t.methodCache[typ] = methods + } +} + +func (t *trashGenerator) chooseRandomMethod(vars map[string]*definedVar) (*types.Func, string) { + t.cacheMethods(vars) + + groupedCandidates := make(map[types.Type][]string) + for name, v := range vars { + typ := deref(v.Type) + if len(t.methodCache[typ]) == 0 { + continue + } + groupedCandidates[typ] = append(groupedCandidates[typ], name) + } + + if len(groupedCandidates) == 0 { + return nil, "" + } + + candidateTypes := maps.Keys(groupedCandidates) + candidateType := candidateTypes[t.rand.Intn(len(candidateTypes))] + candidates := groupedCandidates[candidateType] + + name := candidates[t.rand.Intn(len(candidates))] + vars[name].AddRef() + + methods := t.methodCache[candidateType] + return methods[t.rand.Intn(len(methods))], name +} + +func (t *trashGenerator) generateCall(vars map[string]*definedVar) ast.Stmt { + var ( + targetFunc *types.Func + targetRecvName string + ) + if t.rand.Float32() < methodCallProb { + targetFunc, targetRecvName = t.chooseRandomMethod(vars) + } + + if targetFunc == nil { + targetPkg := t.pkgFunctions[t.rand.Intn(len(t.pkgFunctions))] + targetFunc = targetPkg[t.rand.Intn(len(targetPkg))] + } + + var args []ast.Expr + + targetSig := targetFunc.Type().(*types.Signature) + params := targetSig.Params() + for i := 0; i < params.Len(); i++ { + param := params.At(i) + if !targetSig.Variadic() || i != params.Len()-1 { + args = append(args, t.generateRandomValue(param.Type(), vars)) + continue + } + + variadicCount := t.rand.Intn(maxVariadicParams) + for i := 0; i < variadicCount; i++ { + sliceTyp, ok := param.Type().(*types.Slice) + if !ok { + panic(fmt.Errorf("unsupported variadic type: %v", param.Type())) + } + args = append(args, t.generateRandomValue(sliceTyp.Elem(), vars)) + } + } + + var fun ast.Expr + if targetSig.Recv() != nil { + if len(targetRecvName) == 0 { + panic("recv var must be set") + } + fun = ah.SelectExpr(ast.NewIdent(targetRecvName), ast.NewIdent(targetFunc.Name())) + } else if pkgIdent := t.importNameResolver(targetFunc.Pkg()); pkgIdent != nil { + fun = ah.SelectExpr(pkgIdent, ast.NewIdent(targetFunc.Name())) + } else { + fun = ast.NewIdent(targetFunc.Name()) + } + + callExpr := ah.CallExpr(fun, args...) + results := targetSig.Results() + if results == nil { + return ah.ExprStmt(callExpr) + } + + assignStmt := &ast.AssignStmt{ + Tok: token.ASSIGN, + Rhs: []ast.Expr{callExpr}, + } + + for i := 0; i < results.Len(); i++ { + ident := ast.NewIdent(getRandomName(t.rand)) + vars[ident.Name] = &definedVar{ + Type: results.At(i).Type(), + Ident: ident, + Assign: assignStmt, + } + assignStmt.Lhs = append(assignStmt.Lhs, ident) + } + return assignStmt +} + +func (t *trashGenerator) generateAssign(vars map[string]*definedVar) ast.Stmt { + var varNames []string + for name, d := range vars { + if d.HasRefs() && isSupportedType(d.Type) { + varNames = append(varNames, name) + } + } + t.rand.Shuffle(len(varNames), func(i, j int) { + varNames[i], varNames[j] = varNames[j], varNames[i] + }) + + varCount := 1 + t.rand.Intn(maxAssignVars) + if varCount > len(varNames) { + varCount = len(varNames) + } + + assignStmt := &ast.AssignStmt{ + Tok: token.ASSIGN, + } + for _, name := range varNames[:varCount] { + d := vars[name] + d.AddRef() + + assignStmt.Lhs = append(assignStmt.Lhs, ast.NewIdent(name)) + assignStmt.Rhs = append(assignStmt.Rhs, t.generateRandomValue(d.Type, vars)) + } + return assignStmt +} + +func (t *trashGenerator) Generate(statementCount int, externalVars map[string]types.Type) []ast.Stmt { + vars := make(map[string]*definedVar) + for name, typ := range externalVars { + vars[name] = &definedVar{Type: typ, External: true} + } + + var stmts []ast.Stmt + for i := 0; i < statementCount; i++ { + var stmt ast.Stmt + if len(vars) >= minVarsForAssign && t.rand.Float32() < assignVarProb { + stmt = t.generateAssign(vars) + } else { + stmt = t.generateCall(vars) + } + stmts = append(stmts, stmt) + } + + for _, v := range vars { + if v.Ident != nil && !v.HasRefs() { + v.Ident.Name = "_" + } else if v.Assign != nil { + v.Assign.Tok = token.DEFINE + } + } + return stmts +} diff --git a/internal/ctrlflow/trash_test.go b/internal/ctrlflow/trash_test.go new file mode 100644 index 00000000..028d6de2 --- /dev/null +++ b/internal/ctrlflow/trash_test.go @@ -0,0 +1,100 @@ +package ctrlflow + +import ( + "fmt" + "go/ast" + "go/importer" + "go/printer" + "go/token" + "go/types" + mathrand "math/rand" + "os" + "strconv" + "testing" + + "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" + ah "mvdan.cc/garble/internal/asthelper" +) + +func Test_generateTrashBlock(t *testing.T) { + const ( + seed = 7777 + stmtCount = 1024 + ) + + fset := token.NewFileSet() + buildPkg := func(f *ast.File) *ssa.Package { + ssaPkg, _, err := ssautil.BuildPackage(&types.Config{Importer: importer.Default()}, fset, types.NewPackage("test/main", ""), []*ast.File{f}, 0) + if err != nil { + t.Fatal(err) + } + return ssaPkg + } + + body := &ast.BlockStmt{} + file := &ast.File{ + Name: ast.NewIdent("main"), + Decls: []ast.Decl{ + &ast.GenDecl{ + Tok: token.IMPORT, + Specs: []ast.Spec{ + &ast.ImportSpec{ + Name: ast.NewIdent("_"), + Path: ah.StringLit("os"), + }, + &ast.ImportSpec{ + Name: ast.NewIdent("_"), + Path: ah.StringLit("math"), + }, + &ast.ImportSpec{ + Name: ast.NewIdent("_"), + Path: ah.StringLit("fmt"), + }, + }, + }, + &ast.FuncDecl{ + Name: ast.NewIdent("main"), + Type: &ast.FuncType{Params: &ast.FieldList{}}, + Body: body, + }, + }, + } + beforeSsaPkg := buildPkg(file) + + imports := make(map[string]string) + gen := newTrashGenerator(beforeSsaPkg.Prog, func(pkg *types.Package) *ast.Ident { + if pkg == nil || pkg.Path() == beforeSsaPkg.Pkg.Path() { + return nil + } + + name, ok := imports[pkg.Path()] + if !ok { + name = importPrefix + strconv.Itoa(len(imports)) + imports[pkg.Path()] = name + astutil.AddNamedImport(fset, file, name, pkg.Path()) + } + return ast.NewIdent(name) + }, mathrand.New(mathrand.NewSource(seed))) + + predefinedArgs := make(map[string]types.Type) + for i := types.Bool; i < types.UnsafePointer; i++ { + name, typ := fmt.Sprintf("v%d", i), types.Typ[i] + predefinedArgs[name] = typ + body.List = append(body.List, + &ast.DeclStmt{Decl: &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{&ast.ValueSpec{ + Names: []*ast.Ident{ast.NewIdent(name)}, + Type: ast.NewIdent(typ.Name()), + }}, + }}, + ah.AssignStmt(ast.NewIdent("_"), ast.NewIdent(name)), + ) + } + + body.List = append(body.List, gen.Generate(stmtCount, predefinedArgs)...) + printer.Fprint(os.Stdout, fset, file) + buildPkg(file) +} diff --git a/internal/ssa2ast/func.go b/internal/ssa2ast/func.go index 824b03b6..43fe183c 100644 --- a/internal/ssa2ast/func.go +++ b/internal/ssa2ast/func.go @@ -15,7 +15,11 @@ import ( ah "mvdan.cc/garble/internal/asthelper" ) -var ErrUnsupported = errors.New("unsupported") +var ( + ErrUnsupported = errors.New("unsupported") + + MarkerInstr = &ssa.Panic{} +) type NameType int @@ -33,6 +37,10 @@ type ConverterConfig struct { // Note: Replacing ssa.Expr does not guarantee the correctness of the generated code. // When using it, strictly adhere to the value types. SsaValueRemap map[ssa.Value]ast.Expr + + // MarkerInstrCallback is called every time a MarkerInstr instruction is encountered. + // Callback result is inserted into ast as is + MarkerInstrCallback func(vars map[string]types.Type) []ast.Stmt } func DefaultConfig() *ConverterConfig { @@ -50,11 +58,12 @@ func defaultImportNameResolver(pkg *types.Package) *ast.Ident { } type funcConverter struct { - importNameResolver ImportNameResolver - tc *typeConverter - namePrefix string - valueNameMap map[ssa.Value]string - ssaValueRemap map[ssa.Value]ast.Expr + importNameResolver ImportNameResolver + tc *TypeConverter + namePrefix string + valueNameMap map[ssa.Value]string + ssaValueRemap map[ssa.Value]ast.Expr + markerInstrCallback func(map[string]types.Type) []ast.Stmt } func Convert(ssaFunc *ssa.Function, cfg *ConverterConfig) (*ast.FuncDecl, error) { @@ -63,11 +72,12 @@ func Convert(ssaFunc *ssa.Function, cfg *ConverterConfig) (*ast.FuncDecl, error) func newFuncConverter(cfg *ConverterConfig) *funcConverter { return &funcConverter{ - importNameResolver: cfg.ImportNameResolver, - tc: &typeConverter{resolver: cfg.ImportNameResolver}, - namePrefix: cfg.NamePrefix, - valueNameMap: make(map[ssa.Value]string), - ssaValueRemap: cfg.SsaValueRemap, + importNameResolver: cfg.ImportNameResolver, + tc: &TypeConverter{resolver: cfg.ImportNameResolver}, + namePrefix: cfg.NamePrefix, + valueNameMap: make(map[ssa.Value]string), + ssaValueRemap: cfg.SsaValueRemap, + markerInstrCallback: cfg.MarkerInstrCallback, } } @@ -492,6 +502,14 @@ func (fc *funcConverter) convertBlock(astFunc *AstFunc, ssaBlock *ssa.BasicBlock } for _, instr := range ssaBlock.Instrs[:len(ssaBlock.Instrs)-1] { + if instr == MarkerInstr { + if fc.markerInstrCallback == nil { + panic("marker callback is nil") + } + astBlock.Body = append(astBlock.Body, nil) + continue + } + var stmt ast.Stmt switch instr := instr.(type) { case *ssa.Alloc: @@ -1139,6 +1157,18 @@ func (fc *funcConverter) convertToStmts(ssaFunc *ssa.Function) ([]ast.Stmt, erro } for _, block := range f.Blocks { + if fc.markerInstrCallback != nil { + var newBody []ast.Stmt + for _, stmt := range block.Body { + if stmt != nil { + newBody = append(newBody, stmt) + } else { + newBody = append(newBody, fc.markerInstrCallback(f.Vars)...) + } + } + block.Body = newBody + } + blockStmts := &ast.BlockStmt{List: append(block.Body, block.Phi...)} blockStmts.List = append(blockStmts.List, block.Exit) if block.HasRefs { diff --git a/internal/ssa2ast/polyfill.go b/internal/ssa2ast/polyfill.go index 28791a46..3aa2dbef 100644 --- a/internal/ssa2ast/polyfill.go +++ b/internal/ssa2ast/polyfill.go @@ -6,7 +6,7 @@ import ( "go/types" ) -func makeMapIteratorPolyfill(tc *typeConverter, mapType *types.Map) (ast.Expr, types.Type, error) { +func makeMapIteratorPolyfill(tc *TypeConverter, mapType *types.Map) (ast.Expr, types.Type, error) { keyTypeExpr, err := tc.Convert(mapType.Key()) if err != nil { return nil, nil, err diff --git a/internal/ssa2ast/type.go b/internal/ssa2ast/type.go index ab8fb7b5..fc7f70f7 100644 --- a/internal/ssa2ast/type.go +++ b/internal/ssa2ast/type.go @@ -9,11 +9,15 @@ import ( "strconv" ) -type typeConverter struct { +type TypeConverter struct { resolver ImportNameResolver } -func (tc *typeConverter) Convert(t types.Type) (ast.Expr, error) { +func NewTypeConverted(resolver ImportNameResolver) *TypeConverter { + return &TypeConverter{resolver: resolver} +} + +func (tc *TypeConverter) Convert(t types.Type) (ast.Expr, error) { switch typ := t.(type) { case *types.Array: eltExpr, err := tc.Convert(typ.Elem()) diff --git a/internal/ssa2ast/type_test.go b/internal/ssa2ast/type_test.go index 403471be..f07a55e3 100644 --- a/internal/ssa2ast/type_test.go +++ b/internal/ssa2ast/type_test.go @@ -95,7 +95,7 @@ func TestTypeToExpr(t *testing.T) { f, _, info, _ := mustParseAndTypeCheckFile(typesSrc) name, structAst := findStruct(f, "exampleStruct") obj := info.Defs[name] - fc := &typeConverter{resolver: defaultImportNameResolver} + fc := &TypeConverter{resolver: defaultImportNameResolver} convAst, err := fc.Convert(obj.Type().Underlying()) if err != nil { t.Fatal(err) diff --git a/testdata/script/ctrlflow.txtar b/testdata/script/ctrlflow.txtar index 085b12db..a8090155 100644 --- a/testdata/script/ctrlflow.txtar +++ b/testdata/script/ctrlflow.txtar @@ -78,7 +78,7 @@ func multiHardeningTest(i int) int { return multiply(i); } -//garble:controlflow flatten_passes=1 junk_jumps=10 block_splits=10 +//garble:controlflow flatten_passes=1 junk_jumps=10 block_splits=10 trash_blocks=32 func main() { // Reference to the unexported interface triggers creation of a new interface // with a list of all functions of the private interface