diff --git a/main.go b/main.go index 91566e9a9..beed9e982 100644 --- a/main.go +++ b/main.go @@ -18,18 +18,19 @@ import ( "github.com/zmap/zlint/zlint/ringbuff" ) -const CHUNKSIZE int = 10000 //number of certs per work unit, must be >=1 -const THREADS int = 4 //number of processing threads for --threads mode, must be >=1 +const CHUNKSIZE int = 10000 //number of certs per work unit, must be >=1 +const DEFAULT_THREADS uint = 4 //default number of processing threads for -threads mode, must be >=1 var ( //flags - inPath string - outPath string - outStat string - multi bool - threaded bool + inPath string + outPath string + outStat string + multi bool + threaded bool + numThreads uint ) -var ( //sync values for --threads +var ( //sync values for -threads inBuffer ringbuff.RingBuffer outBuffer ringbuff.RingBuffer poisonBarrier sync.WaitGroup //used prevent outBuffer from being poisoned before Enqueueing is complete @@ -44,7 +45,8 @@ func init() { flag.StringVar(&outPath, "out-file", "-", "File path for the output JSON.") flag.StringVar(&outStat, "out-stat", "-", "File path for the output stats.") flag.BoolVar(&multi, "multi", false, "Use this flag to specify inserting many certs at once. Certs in this mode must be Base64 encoded DER strings, one per line.") - flag.BoolVar(&threaded, "threads", false, "Use this flag to specify that --multi mode runs multi-threaded. This has no effect otherwise.") + flag.BoolVar(&threaded, "threads", false, "Use this flag to specify that -multi mode runs multi-threaded. This has no effect otherwise.") + flag.UintVar(&numThreads, "num-threads", DEFAULT_THREADS, "Use this flag to specify the number of threads in -threads mode. This has no effect otherwise.") flag.Parse() } @@ -170,13 +172,13 @@ func threadMode() { inBuffer.Init(30) outBuffer.Init(40) mainWait = sync.NewCond(&exittex) - poisonBarrier.Add(1 + THREADS) //requires all processing threads AND the reader to Done() + poisonBarrier.Add(1 + int(numThreads)) //requires all processing threads AND the reader to Done() //initiate reader go readChunks() //initiate processing - for i := 1; i <= THREADS; i++ { + for i := 1; i <= int(numThreads); i++ { go processChunks() }