Skip to content

Commit

Permalink
optimize generator by caching sum and ordered keys
Browse files Browse the repository at this point in the history
BenchmarkChain_GenerateDeterministic:
before: 635533 ns/op
 after: 89687 ns/op
Time savings: 86%.
  • Loading branch information
starius committed Aug 4, 2023
1 parent 01c8029 commit 68aee51
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions gomarkov.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@ const (
EndToken = "^"
)

// preprocessedArray bundles a sparseArray with two values derived from it, so
// that code reading the chain can use the stored fields instead of calling
// sum() and orderedKeys() on the sparseArray each time.
type preprocessedArray struct {
sparseArray

// sum holds the same value as sparseArray.sum(); Add increments it together
// with each count it increments in the sparseArray.
sum int
// orderedKeys holds the result of sparseArray.orderedKeys(); Add recomputes
// it whenever its length differs from the number of sparseArray entries.
orderedKeys []int
}

// Chain is a markov chain instance.
//
// NOTE(review): this is a diff view without +/- markers, so two declarations
// of frequencyMat appear below. Presumably the map[int]sparseArray line is the
// removed one and the map[int]preprocessedArray line is its replacement —
// confirm against the raw diff.
type Chain struct {
// Order is serialized by MarshalJSON; presumably the n-gram order given to
// NewChain — TODO confirm, the assignment is outside this view.
Order int
// statePool assigns an int index to each state string (see statePool.add in
// Add) and keeps the stringMap/intMap lookup tables.
statePool *spool
frequencyMat map[int]sparseArray
// frequencyMat is keyed by the index of the current state; each value counts
// how often every next-state index followed it.
frequencyMat map[int]preprocessedArray
// lock is held for writing by Add while it updates frequencyMat.
lock *sync.RWMutex
}

Expand All @@ -40,10 +47,14 @@ var defaultPrng = rand.New(rand.NewSource(time.Now().UnixNano()))

// MarshalJSON encodes the chain by building a chainJSON value from Order, the
// state pool's stringMap and a frequency map, and passing it to json.Marshal.
// The frequency map built here contains only the embedded sparseArray of each
// frequencyMat entry, not the stored sum or orderedKeys.
//
// NOTE(review): this is a diff view without +/- markers, so the chainJSON
// literal below shows two third arguments. Presumably `chain.frequencyMat,` is
// the removed line and `frequencyMat,` is its replacement — confirm against
// the raw diff.
// NOTE(review): no lock is taken in the visible body — confirm whether callers
// are expected to avoid concurrent Add calls.
func (chain Chain) MarshalJSON() ([]byte, error) {
// Copy out the plain sparseArray of every entry, pre-sizing the map to the
// number of entries.
frequencyMat := make(map[int]sparseArray, len(chain.frequencyMat))
for k, v := range chain.frequencyMat {
frequencyMat[k] = v.sparseArray
}
obj := chainJSON{
chain.Order,
chain.statePool.stringMap,
chain.frequencyMat,
frequencyMat,
}
return json.Marshal(obj)
}
Expand All @@ -64,7 +75,14 @@ func (chain *Chain) UnmarshalJSON(b []byte) error {
stringMap: obj.SpoolMap,
intMap: intMap,
}
chain.frequencyMat = obj.FreqMat
chain.frequencyMat = make(map[int]preprocessedArray, len(obj.FreqMat))
for k, v := range obj.FreqMat {
chain.frequencyMat[k] = preprocessedArray{
sparseArray: v,
sum: v.sum(),
orderedKeys: v.orderedKeys(),
}
}
chain.lock = new(sync.RWMutex)
return nil
}
Expand All @@ -76,7 +94,7 @@ func NewChain(order int) *Chain {
stringMap: make(map[string]int),
intMap: make(map[int]string),
}
chain.frequencyMat = make(map[int]sparseArray, 0)
chain.frequencyMat = make(map[int]preprocessedArray)
chain.lock = new(sync.RWMutex)
return &chain
}
Expand All @@ -95,10 +113,18 @@ func (chain *Chain) Add(input []string) {
currentIndex := chain.statePool.add(pair.CurrentState.key())
nextIndex := chain.statePool.add(pair.NextState)
chain.lock.Lock()
if chain.frequencyMat[currentIndex] == nil {
chain.frequencyMat[currentIndex] = make(sparseArray, 0)
pa, has := chain.frequencyMat[currentIndex]
if !has {
pa = preprocessedArray{
sparseArray: make(sparseArray),
}
}
pa.sparseArray[nextIndex]++
pa.sum++
if len(pa.orderedKeys) != len(pa.sparseArray) {
pa.orderedKeys = pa.sparseArray.orderedKeys()
}
chain.frequencyMat[currentIndex][nextIndex]++
chain.frequencyMat[currentIndex] = pa
chain.lock.Unlock()
}
}
Expand All @@ -114,8 +140,8 @@ func (chain *Chain) TransitionProbability(next string, current NGram) (float64,
return 0, nil
}
arr := chain.frequencyMat[currentIndex]
sum := float64(arr.sum())
freq := float64(arr[nextIndex])
sum := float64(arr.sum)
freq := float64(arr.sparseArray[nextIndex])
return freq / sum, nil
}

Expand All @@ -139,11 +165,9 @@ func (chain *Chain) GenerateDeterministic(current NGram, prng PRNG) (string, err
return "", fmt.Errorf("Unknown ngram %v", current)
}
arr := chain.frequencyMat[currentIndex]
sum := arr.sum()
randN := prng.Intn(sum)
keys := arr.orderedKeys()
for _, key := range keys {
freq := arr[key]
randN := prng.Intn(arr.sum)
for _, key := range arr.orderedKeys {
freq := arr.sparseArray[key]
randN -= freq
if randN <= 0 {
return chain.statePool.intMap[key], nil
Expand Down

0 comments on commit 68aee51

Please sign in to comment.