diff --git a/code/failpoint.go b/code/failpoint.go index 0992ea0..6ea7de7 100644 --- a/code/failpoint.go +++ b/code/failpoint.go @@ -27,6 +27,10 @@ type Failpoint struct { // whitespace for padding ws string + + // if true, do not acquire read lock on failpoints + // useful for "continue" + eval bool } // newFailpoint makes a new failpoint based on the a line containing a @@ -38,10 +42,11 @@ func newFailpoint(l string) (*Failpoint, error) { } cmd := strings.SplitAfter(l, "// gofail:")[1] fields := strings.Fields(cmd) - if len(fields) != 3 || fields[0] != "var" { + if len(fields) < 3 || fields[0] != "var" { return nil, fmt.Errorf("failpoint: malformed comment header %q", l) } - return &Failpoint{name: fields[1], varType: fields[2], ws: strings.Split(l, "//")[0]}, nil + eval := len(fields) == 4 && fields[3] == "eval" + return &Failpoint{name: fields[1], varType: fields[2], ws: strings.Split(l, "//")[0], eval: eval}, nil } // flush writes the failpoint code to a buffer @@ -53,14 +58,19 @@ func (fp *Failpoint) flush(dst io.Writer) error { } func (fp *Failpoint) hdr(varname string) string { - hdr := fp.ws + "if v" + fp.name + ", __fpErr := " + fp.Runtime() + ".Acquire(); __fpErr == nil { " - hdr = hdr + "defer " + fp.Runtime() + ".Release(); " + hdr := fp.ws + "if v" + fp.name + ", __fpErr := " + if !fp.eval { + hdr += fp.Runtime() + ".Acquire(); __fpErr == nil { " + hdr += "defer " + fp.Runtime() + ".Release(); " + } else { + hdr += fp.Runtime() + ".Eval(); __fpErr == nil { " + } if fp.varType == "struct{}" { // unused varname = "_" } return hdr + varname + ", __fpTypeOK := v" + fp.name + - ".(" + fp.varType + "); if !__fpTypeOK { goto __badType" + fp.name + "} " + ".(" + fp.varType + "); if !__fpTypeOK { goto __badType" + fp.name + " } else { continue __fp_" + fp.name + " }" } func (fp *Failpoint) footer() string { @@ -69,6 +79,10 @@ func (fp *Failpoint) footer() string { } func (fp *Failpoint) flushSingle(dst io.Writer) error { + if fp.varType == "continue" { + _, cerr := io.WriteString(dst, "__fp_"+fp.name+":\n") + return cerr + } if _, err := io.WriteString(dst, fp.hdr("_")); err != nil { return err } diff --git a/code/rewrite.go b/code/rewrite.go index dcc4405..38a69e0 100644 --- a/code/rewrite.go +++ b/code/rewrite.go @@ -48,7 +48,9 @@ func ToFailpoints(wdst io.Writer, rsrc io.Reader) (fps []*Failpoint, err error) continue } else { curfp.flush(dst) - fps = append(fps, curfp) + if curfp.varType != "continue" { + fps = append(fps, curfp) + } curfp = nil } } else { @@ -81,6 +83,14 @@ func ToComments(wdst io.Writer, rsrc io.Reader) (fps []*Failpoint, err error) { err = rerr lTrim := strings.TrimSpace(l) + isContinue := strings.HasPrefix(lTrim, "__fp_") && lTrim[len(lTrim)-1] == ':' + if isContinue { + n := strings.Split(strings.Split(lTrim, "__fp_")[1], ":")[0] + dst.WriteString("\t// gofail: var " + n + " continue" + "\n") + fps = append(fps, &Failpoint{name: n, varType: "continue"}) + continue + } + if unmatchedBraces > 0 { opening, closing := numBraces(l) unmatchedBraces += opening - closing @@ -98,6 +108,9 @@ func ToComments(wdst io.Writer, rsrc io.Reader) (fps []*Failpoint, err error) { ws = strings.Split(l, "i")[0] n := strings.Split(strings.Split(l, "__fp_")[1], ".")[0] t := strings.Split(strings.Split(l, ".(")[1], ")")[0] + if strings.Contains(lTrim, n+"."+"Eval();") { + t += " eval" + } dst.WriteString(ws + "// gofail: var " + n + " " + t + "\n") if !strings.Contains(l, "; __badType") { // not single liner diff --git a/runtime/failpoint.go b/runtime/failpoint.go index 0eee542..4e21d5f 100644 --- a/runtime/failpoint.go +++ b/runtime/failpoint.go @@ -30,6 +30,21 @@ func NewFailpoint(pkg, name string) *Failpoint { return fp } +// Eval merely evaluates the failpoint terms without +// acquiring the read lock. +func (fp *Failpoint) Eval() (interface{}, error) { + fp.mu.RLock() + defer fp.mu.RUnlock() + if fp.t == nil { + return nil, ErrDisabled + } + v, err := fp.t.eval() + if v == nil { + err = ErrDisabled + } + return v, err +} + // Acquire gets evalutes the failpoint terms; if the failpoint // is active, it will return a value. Otherwise, returns a non-nil error. func (fp *Failpoint) Acquire() (interface{}, error) {