diff --git a/pkg/workload/tpcc/tpcc.go b/pkg/workload/tpcc/tpcc.go index be3e8cc31412..43e28f2d59f9 100644 --- a/pkg/workload/tpcc/tpcc.go +++ b/pkg/workload/tpcc/tpcc.go @@ -15,6 +15,7 @@ import ( gosql "database/sql" "fmt" "net/url" + "strconv" "strings" "sync" "time" @@ -48,11 +49,11 @@ type tpcc struct { // is the value of C for the item id generator. See 2.1.6. cLoad, cCustomerID, cItemID int - mix string - doWaits bool - workers int - fks bool - dbOverride string + mix string + waitFraction float64 + workers int + fks bool + dbOverride string txInfos []txInfo // deck contains indexes into the txInfos slice. @@ -83,6 +84,45 @@ type tpcc struct { localsPool *sync.Pool } +type waitSetter struct { + val *float64 +} + +// Set implements the pflag.Value interface. +func (w *waitSetter) Set(val string) error { + switch strings.ToLower(val) { + case "true", "on": + *w.val = 1.0 + case "false", "off": + *w.val = 0.0 + default: + f, err := strconv.ParseFloat(val, 64) + if err != nil { + return err + } + if f < 0 { + return errors.New("cannot set --wait to a negative value") + } + *w.val = f + } + return nil +} + +// Type implements the pflag.Value interface +func (*waitSetter) Type() string { return "0.0/false - 1.0/true" } + +// String implements the pflag.Value interface. +func (w *waitSetter) String() string { + switch *w.val { + case 0: + return "false" + case 1: + return "true" + default: + return fmt.Sprintf("%f", *w.val) + } +} + func init() { workload.Register(tpccMeta) } @@ -127,7 +167,8 @@ var tpccMeta = workload.Meta{ g.flags.StringVar(&g.mix, `mix`, `newOrder=10,payment=10,orderStatus=1,delivery=1,stockLevel=1`, `Weights for the transaction mix. The default matches the TPCC spec.`) - g.flags.BoolVar(&g.doWaits, `wait`, true, `Run in wait mode (include think/keying sleeps)`) + g.waitFraction = 1.0 + g.flags.Var(&waitSetter{&g.waitFraction}, `wait`, `Wait mode (include think/keying sleeps): 1/true for tpcc-standard wait, 0/false for no waits, other factors also allowed`) g.flags.StringVar(&g.dbOverride, `db`, ``, `Override for the SQL database to use. If empty, defaults to the generator name`) g.flags.IntVar(&g.workers, `workers`, 0, fmt.Sprintf( @@ -200,15 +241,15 @@ func (w *tpcc) Hooks() workload.Hooks { // waiting, we only use up to a set number of connections per warehouse. // This isn't mandated by the spec, but opening a connection per worker // when they each spend most of their time waiting is wasteful. - if !w.doWaits { + if w.waitFraction == 0 { w.numConns = w.workers } else { w.numConns = w.activeWarehouses * numConnsPerWarehouse } } - if w.doWaits && w.workers != w.activeWarehouses*numWorkersPerWarehouse { - return errors.Errorf(`--wait=true and --warehouses=%d requires --workers=%d`, + if w.waitFraction > 0 && w.workers != w.activeWarehouses*numWorkersPerWarehouse { + return errors.Errorf(`--wait > 0 and --warehouses=%d requires --workers=%d`, w.activeWarehouses, w.warehouses*numWorkersPerWarehouse) } diff --git a/pkg/workload/tpcc/worker.go b/pkg/workload/tpcc/worker.go index ec1720b4edae..ea1cbdef8417 100644 --- a/pkg/workload/tpcc/worker.go +++ b/pkg/workload/tpcc/worker.go @@ -178,12 +178,10 @@ func (w *worker) run(ctx context.Context) error { w.permIdx++ warehouseID := w.warehouse - if w.config.doWaits { - // Wait out the entire keying and think time even if the context is - // expired. This prevents all workers from immediately restarting when - // the workload's ramp period expires, which can overload a cluster. - time.Sleep(time.Duration(txInfo.keyingTime) * time.Second) - } + // Wait out the entire keying and think time even if the context is + // expired. This prevents all workers from immediately restarting when + // the workload's ramp period expires, which can overload a cluster. + time.Sleep(time.Duration(float64(txInfo.keyingTime) * float64(time.Second) * w.config.waitFraction)) // Run transactions with a background context because we don't want to // cancel them when the context expires. Instead, let them finish normally @@ -197,16 +195,14 @@ func (w *worker) run(ctx context.Context) error { w.hists.Get(txInfo.name).Record(elapsed) } - if w.config.doWaits { - // 5.2.5.4: Think time is taken independently from a negative exponential - // distribution. Think time = -log(r) * u, where r is a uniform random number - // between 0 and 1 and u is the mean think time per operation. - // Each distribution is truncated at 10 times its mean value. - thinkTime := -math.Log(rand.Float64()) * txInfo.thinkTime - if thinkTime > (txInfo.thinkTime * 10) { - thinkTime = txInfo.thinkTime * 10 - } - time.Sleep(time.Duration(thinkTime) * time.Second) + // 5.2.5.4: Think time is taken independently from a negative exponential + // distribution. Think time = -log(r) * u, where r is a uniform random number + // between 0 and 1 and u is the mean think time per operation. + // Each distribution is truncated at 10 times its mean value. + thinkTime := -math.Log(rand.Float64()) * txInfo.thinkTime + if thinkTime > (txInfo.thinkTime * 10) { + thinkTime = txInfo.thinkTime * 10 } + time.Sleep(time.Duration(thinkTime * float64(time.Second) * w.config.waitFraction)) return ctx.Err() }