From 27cf60abd7b3c579589d01eed6b7b16e7a02a872 Mon Sep 17 00:00:00 2001 From: Tony Spataro Date: Sun, 2 Apr 2023 14:45:15 -0700 Subject: [PATCH 1/6] Add GenerateDeterministic --- gomarkov.go | 54 ++++++++++++++++++++++++++++++++++++++---------- gomarkov_test.go | 30 +++++++++++++++++++++++++++ helpers.go | 20 ++++++++++++++---- 3 files changed, 89 insertions(+), 15 deletions(-) diff --git a/gomarkov.go b/gomarkov.go index 5384949..4dce351 100644 --- a/gomarkov.go +++ b/gomarkov.go @@ -9,14 +9,14 @@ import ( "time" ) -//Tokens are wrapped around a sequence of words to maintain the -//start and end transition counts +// Tokens are wrapped around a sequence of words to maintain the +// start and end transition counts const ( StartToken = "$" EndToken = "^" ) -//Chain is a markov chain instance +// Chain is a markov chain instance type Chain struct { Order int statePool *spool @@ -24,13 +24,21 @@ type Chain struct { lock *sync.RWMutex } +// PRNG is a pseudo-random number generator compatible with math/rand interfaces. +type PRNG interface { + // Intn returns a pseudo-random number in the half-open interval [0,n) + Intn(int) int +} + type chainJSON struct { Order int `json:"int"` SpoolMap map[string]int `json:"spool_map"` FreqMat map[int]sparseArray `json:"freq_mat"` } -//MarshalJSON ... +var defaultPrng = rand.New(rand.NewSource(time.Now().UnixNano())) + +// MarshalJSON ... func (chain Chain) MarshalJSON() ([]byte, error) { obj := chainJSON{ chain.Order, @@ -40,7 +48,7 @@ func (chain Chain) MarshalJSON() ([]byte, error) { return json.Marshal(obj) } -//UnmarshalJSON ... +// UnmarshalJSON ... 
func (chain *Chain) UnmarshalJSON(b []byte) error { var obj chainJSON err := json.Unmarshal(b, &obj) @@ -61,7 +69,7 @@ func (chain *Chain) UnmarshalJSON(b []byte) error { return nil } -//NewChain creates an instance of Chain +// NewChain creates an instance of Chain func NewChain(order int) *Chain { chain := Chain{Order: order} chain.statePool = &spool{ @@ -73,7 +81,7 @@ func NewChain(order int) *Chain { return &chain } -//Add adds the transition counts to the chain for a given sequence of words +// Add adds the transition counts to the chain for a given sequence of words func (chain *Chain) Add(input []string) { startTokens := array(StartToken, chain.Order) endTokens := array(EndToken, chain.Order) @@ -95,7 +103,7 @@ func (chain *Chain) Add(input []string) { } } -//TransitionProbability returns the transition probability between two states +// TransitionProbability returns the transition probability between two states func (chain *Chain) TransitionProbability(next string, current NGram) (float64, error) { if len(current) != chain.Order { return 0, errors.New("N-gram length does not match chain order") @@ -111,7 +119,7 @@ func (chain *Chain) TransitionProbability(next string, current NGram) (float64, return freq / sum, nil } -//Generate generates new text based on an initial seed of words +// Generate generates new text based on an initial seed of words func (chain *Chain) Generate(current NGram) (string, error) { if len(current) != chain.Order { return "", errors.New("N-gram length does not match chain order") @@ -136,6 +144,30 @@ func (chain *Chain) Generate(current NGram) (string, error) { return "", nil } -func init() { - rand.Seed(time.Now().UnixNano()) +// GenerateDeterministic generates new text deterministically, based on an initial seed of words and using a specified PRNG. +// Use it for reproducibly pseudo-random results (i.e. pass the same PRNG and same state every time). 
+func (chain *Chain) GenerateDeterministic(current NGram, prng PRNG) (string, error) { + if len(current) != chain.Order { + return "", errors.New("N-gram length does not match chain order") + } + if current[len(current)-1] == EndToken { + // Dont generate anything after the end token + return "", nil + } + currentIndex, currentExists := chain.statePool.get(current.key()) + if !currentExists { + return "", fmt.Errorf("Unknown ngram %v", current) + } + arr := chain.frequencyMat[currentIndex] + sum := arr.sum() + randN := prng.Intn(sum) + keys := arr.orderedKeys() + for _, key := range keys { + freq := arr[key] + randN -= freq + if randN <= 0 { + return chain.statePool.intMap[key], nil + } + } + return "", nil } diff --git a/gomarkov_test.go b/gomarkov_test.go index 35b94a6..bcc0685 100644 --- a/gomarkov_test.go +++ b/gomarkov_test.go @@ -1,6 +1,7 @@ package gomarkov import ( + "math/rand" "reflect" "testing" ) @@ -240,3 +241,32 @@ func TestChain_Generate(t *testing.T) { }) } } + +func TestChain_GenerateDeterministic(t *testing.T) { + chain := NewChain(2) + chain.Add(NGram{"i", "like", "bees"}) + chain.Add(NGram{"i", "like", "cake"}) + chain.Add(NGram{"i", "like", "pizza"}) + chain.Add(NGram{"i", "like", "tacos"}) + + pairs := map[int64]string{ + 0: "cake", + 1: "bees", + 10: "cake", + 100: "pizza", + 1000: "bees", + } + for seed, expected := range pairs { + for i := 0; i < 16; i++ { + prng := rand.New(rand.NewSource(seed)) + got, err := chain.GenerateDeterministic(NGram{"i", "like"}, prng) + if err != nil { + panic(err) // you wrote a bad test and should feel bad + } + if got != expected { + t.Errorf("Chain.GenerateDeterministic() is not deterministic; seed = %d, got = %q, want %q", seed, got, expected) + break + } + } + } +} diff --git a/helpers.go b/helpers.go index 644faac..7c18840 100644 --- a/helpers.go +++ b/helpers.go @@ -1,14 +1,17 @@ package gomarkov -import "strings" +import ( + "sort" + "strings" +) -//Pair is a pair of consecutive states in a sequece +// 
Pair is a pair of consecutive states in a sequece type Pair struct { CurrentState NGram // n = order of the chain NextState string // n = 1 } -//NGram is a array of words +// NGram is an array of words type NGram []string type sparseArray map[int]int @@ -17,6 +20,15 @@ func (ngram NGram) key() string { return strings.Join(ngram, "_") } +func (s sparseArray) orderedKeys() []int { + keys := make([]int, 0, len(s)) + for k := range s { + keys = append(keys, k) + } + sort.Ints(keys) + return keys +} + func (s sparseArray) sum() int { sum := 0 for _, count := range s { @@ -40,7 +52,7 @@ func array(value string, count int) []string { return arr } -//MakePairs generates n-gram pairs of consecutive states in a sequence +// MakePairs generates n-gram pairs of consecutive states in a sequence func MakePairs(tokens []string, order int) []Pair { var pairs []Pair for i := 0; i < len(tokens)-order; i++ { From 3f463a795f35e5ebd1c4845e36a0afc9ed01bdf7 Mon Sep 17 00:00:00 2001 From: Ryota Kayanuma Date: Sat, 9 Feb 2019 14:12:47 +0900 Subject: [PATCH 2/6] Add GenerateAll function --- gomarkov.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/gomarkov.go b/gomarkov.go index 4dce351..ace0df9 100644 --- a/gomarkov.go +++ b/gomarkov.go @@ -171,3 +171,26 @@ func (chain *Chain) GenerateDeterministic(current NGram, prng PRNG) (string, err } return "", nil } + +//GenerateAll generates whole chain of text from scratch. 
+func (chain *Chain) GenerateAll() ([]string, error) { + generatedText := []string{} + current := make(NGram, 0) + for i := 0; i < chain.Order; i++ { + current = append(current, StartToken) + } + + for { + next, err := chain.Generate(current) + if err != nil { + return []string{}, err + } + if next == EndToken { + break + } + + current = append(current, next)[1:] + generatedText = append(generatedText, next) + } + return generatedText, nil +} From 23fc560aac462e80868f3099fa7eba5c5aa8c1d8 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Tue, 1 Aug 2023 17:08:46 -0300 Subject: [PATCH 3/6] use GenerateDeterministic in Generate Fix copy-paste. --- gomarkov.go | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/gomarkov.go b/gomarkov.go index ace0df9..81674ed 100644 --- a/gomarkov.go +++ b/gomarkov.go @@ -121,27 +121,7 @@ func (chain *Chain) TransitionProbability(next string, current NGram) (float64, // Generate generates new text based on an initial seed of words func (chain *Chain) Generate(current NGram) (string, error) { - if len(current) != chain.Order { - return "", errors.New("N-gram length does not match chain order") - } - if current[len(current)-1] == EndToken { - // Dont generate anything after the end token - return "", nil - } - currentIndex, currentExists := chain.statePool.get(current.key()) - if !currentExists { - return "", fmt.Errorf("Unknown ngram %v", current) - } - arr := chain.frequencyMat[currentIndex] - sum := arr.sum() - randN := rand.Intn(sum) - for i, freq := range arr { - randN -= freq - if randN <= 0 { - return chain.statePool.intMap[i], nil - } - } - return "", nil + return chain.GenerateDeterministic(current, defaultPrng) } // GenerateDeterministic generates new text deterministically, based on an initial seed of words and using a specified PRNG. 
From efcb9b47101842f03e3d549fe5b3ab58e3be5026 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 4 Aug 2023 14:35:29 -0300 Subject: [PATCH 4/6] go fmt --- examples/gibberish/gibberish.go | 2 +- gomarkov.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/gibberish/gibberish.go b/examples/gibberish/gibberish.go index b4e3d54..90a2270 100644 --- a/examples/gibberish/gibberish.go +++ b/examples/gibberish/gibberish.go @@ -40,7 +40,7 @@ func main() { return } score := sequenceProbablity(model.Chain, *username) - normalizedScore := (score - model.Mean) / model.StdDev + normalizedScore := (score - model.Mean) / model.StdDev isGibberish := normalizedScore < 0 fmt.Printf("Score: %f | Gibberish: %t\n", normalizedScore, isGibberish) } diff --git a/gomarkov.go b/gomarkov.go index 81674ed..70a1cd6 100644 --- a/gomarkov.go +++ b/gomarkov.go @@ -152,7 +152,7 @@ func (chain *Chain) GenerateDeterministic(current NGram, prng PRNG) (string, err return "", nil } -//GenerateAll generates whole chain of text from scratch. +// GenerateAll generates whole chain of text from scratch. 
func (chain *Chain) GenerateAll() ([]string, error) { generatedText := []string{} current := make(NGram, 0) From 01c802950867895c9f2cbbe7850dbff75b495d73 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 4 Aug 2023 14:35:52 -0300 Subject: [PATCH 5/6] add benchmark for generator --- go.mod | 5 ++++- go.sum | 16 ++++++++++++++++ gomarkov_test.go | 25 +++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 874731f..5703706 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/mb-14/gomarkov go 1.14 -require github.com/montanaflynn/stats v0.6.3 +require ( + github.com/montanaflynn/stats v0.6.3 + github.com/stretchr/testify v1.8.4 +) diff --git a/go.sum b/go.sum index bad6a1b..b012c43 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,18 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/montanaflynn/stats v0.6.3 h1:F8446DrvIF5V5smZfZ8K9nrmmix0AFgevPdLruGOmzk= github.com/montanaflynn/stats v0.6.3/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify 
v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gomarkov_test.go b/gomarkov_test.go index bcc0685..396ec14 100644 --- a/gomarkov_test.go +++ b/gomarkov_test.go @@ -1,9 +1,13 @@ package gomarkov import ( + "encoding/json" + "io/ioutil" "math/rand" "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestChain_MarshalJSON(t *testing.T) { @@ -270,3 +274,24 @@ func TestChain_GenerateDeterministic(t *testing.T) { } } } + +func BenchmarkChain_GenerateDeterministic(b *testing.B) { + data, err := ioutil.ReadFile("test_model.json") + require.NoError(b, err) + var chain Chain + require.NoError(b, json.Unmarshal(data, &chain)) + b.ResetTimer() + const seed = 100 + for i := 0; i < b.N; i++ { + prng := rand.New(rand.NewSource(seed)) + tokens := []string{StartToken} + for count := 0; count <= 100; count++ { + next, err := chain.GenerateDeterministic(tokens, prng) + require.NoError(b, err) + if next == EndToken { + next = StartToken + } + tokens = []string{next} + } + } +} From 68aee51b89bb066cc5ba2c7aefb776f87ba14400 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 4 Aug 2023 16:33:29 -0300 Subject: [PATCH 6/6] optimize generator by caching sum and ordered keys BenchmarkChain_GenerateDeterministic: before: 635533 ns/op after: 89687 ns/op Time savings: 86%. 
--- gomarkov.go | 52 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/gomarkov.go b/gomarkov.go index 70a1cd6..0787c62 100644 --- a/gomarkov.go +++ b/gomarkov.go @@ -16,11 +16,18 @@ const ( EndToken = "^" ) +type preprocessedArray struct { + sparseArray + + sum int + orderedKeys []int +} + // Chain is a markov chain instance type Chain struct { Order int statePool *spool - frequencyMat map[int]sparseArray + frequencyMat map[int]preprocessedArray lock *sync.RWMutex } @@ -40,10 +47,14 @@ var defaultPrng = rand.New(rand.NewSource(time.Now().UnixNano())) // MarshalJSON ... func (chain Chain) MarshalJSON() ([]byte, error) { + frequencyMat := make(map[int]sparseArray, len(chain.frequencyMat)) + for k, v := range chain.frequencyMat { + frequencyMat[k] = v.sparseArray + } obj := chainJSON{ chain.Order, chain.statePool.stringMap, - chain.frequencyMat, + frequencyMat, } return json.Marshal(obj) } @@ -64,7 +75,14 @@ func (chain *Chain) UnmarshalJSON(b []byte) error { stringMap: obj.SpoolMap, intMap: intMap, } - chain.frequencyMat = obj.FreqMat + chain.frequencyMat = make(map[int]preprocessedArray, len(obj.FreqMat)) + for k, v := range obj.FreqMat { + chain.frequencyMat[k] = preprocessedArray{ + sparseArray: v, + sum: v.sum(), + orderedKeys: v.orderedKeys(), + } + } chain.lock = new(sync.RWMutex) return nil } @@ -76,7 +94,7 @@ func NewChain(order int) *Chain { stringMap: make(map[string]int), intMap: make(map[int]string), } - chain.frequencyMat = make(map[int]sparseArray, 0) + chain.frequencyMat = make(map[int]preprocessedArray) chain.lock = new(sync.RWMutex) return &chain } @@ -95,10 +113,18 @@ func (chain *Chain) Add(input []string) { currentIndex := chain.statePool.add(pair.CurrentState.key()) nextIndex := chain.statePool.add(pair.NextState) chain.lock.Lock() - if chain.frequencyMat[currentIndex] == nil { - chain.frequencyMat[currentIndex] = make(sparseArray, 0) + pa, has := chain.frequencyMat[currentIndex] 
+ if !has { + pa = preprocessedArray{ + sparseArray: make(sparseArray), + } + } + pa.sparseArray[nextIndex]++ + pa.sum++ + if len(pa.orderedKeys) != len(pa.sparseArray) { + pa.orderedKeys = pa.sparseArray.orderedKeys() } - chain.frequencyMat[currentIndex][nextIndex]++ + chain.frequencyMat[currentIndex] = pa chain.lock.Unlock() } } @@ -114,8 +140,8 @@ func (chain *Chain) TransitionProbability(next string, current NGram) (float64, return 0, nil } arr := chain.frequencyMat[currentIndex] - sum := float64(arr.sum()) - freq := float64(arr[nextIndex]) + sum := float64(arr.sum) + freq := float64(arr.sparseArray[nextIndex]) return freq / sum, nil } @@ -139,11 +165,9 @@ func (chain *Chain) GenerateDeterministic(current NGram, prng PRNG) (string, err return "", fmt.Errorf("Unknown ngram %v", current) } arr := chain.frequencyMat[currentIndex] - sum := arr.sum() - randN := prng.Intn(sum) - keys := arr.orderedKeys() - for _, key := range keys { - freq := arr[key] + randN := prng.Intn(arr.sum) + for _, key := range arr.orderedKeys { + freq := arr.sparseArray[key] randN -= freq if randN <= 0 { return chain.statePool.intMap[key], nil