Skip to content

Commit

Permalink
Merge pull request #16 from f1monkey/feature/score-function
Browse files Browse the repository at this point in the history
Feature/score function
  • Loading branch information
cyradin authored Jun 20, 2024
2 parents 0bfd519 + 8b31113 commit 87dfa2a
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 37 deletions.
43 changes: 40 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@ Yet another spellchecker written in go.
## Installation

```
$ go get -v github.com/f1monkey/spellchecker
go get -v github.com/f1monkey/spellchecker
```

## Usage


### Quick start

```go

func main() {
// Create new instance
sc, err := spellchecker.New(
Expand All @@ -41,7 +45,7 @@ func main() {
}
sc.AddFrom(in)

// Add some more words
// Add some words
sc.Add("lock", "stock", "and", "two", "smoking", "barrels")

// Check if a word is correct
Expand All @@ -61,6 +65,12 @@ func main() {
panic(err)
}
fmt.Println(matches) // [range, orange]
```
### Save/load
```go
sc, err := spellchecker.New("abc")

// Save data to any io.Writer
out, err := os.Create("data/out.bin")
Expand All @@ -78,9 +88,36 @@ func main() {
if err != nil {
panic(err)
}
}
```
### Custom score function
You can provide a custom score function if you need to.
```go
var scoreFunc spellchecker.ScoreFunc = func(src, candidate []rune, distance, cnt int) float64 {
return 1.0 // return constant score
}

sc, err := spellchecker.New("abc", spellchecker.WithScoreFunc(scoreFunc))
if err != nil {
// handle err
}

// after you load spellchecker from file
// you will need to provide the function again:
sc, err = spellchecker.Load(inFile)
if err != nil {
// handle err
}

err = sc.WithOpts(spellchecker.WithScoreFunc(scoreFunc))
if err != nil {
// handle err
}
```
## Benchmarks
Tests are based on data from [Peter Norvig's article about spelling correction](http://norvig.com/spell-correct.html)
Expand Down
27 changes: 9 additions & 18 deletions dictionary.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import (
"bytes"
"encoding"
"encoding/gob"
"math"
"sort"
"sync/atomic"

"github.com/agnivade/levenshtein"
"github.com/f1monkey/bitmap"
)

type scoreFunc func(src []rune, candidate []rune, distance int, cnt int) float64

type dictionary struct {
maxErrors int
alphabet alphabet
Expand All @@ -22,9 +23,11 @@ type dictionary struct {
counts map[uint32]int

index map[uint64][]uint32

scoreFunc scoreFunc
}

func newDictionary(ab string, maxErrors int) (*dictionary, error) {
func newDictionary(ab string, scoreFunc scoreFunc, maxErrors int) (*dictionary, error) {
alphabet, err := newAlphabet(ab)
if err != nil {
return nil, err
Expand All @@ -38,6 +41,7 @@ func newDictionary(ab string, maxErrors int) (*dictionary, error) {
words: make(map[uint32]string),
counts: make(map[uint32]int),
index: make(map[uint64][]uint32),
scoreFunc: scoreFunc,
}, nil
}

Expand Down Expand Up @@ -111,7 +115,7 @@ func (d *dictionary) getCandidates(word string, max int) []match {
}
result.Push(match{
Value: docWord,
Score: calcScore(wordRunes, []rune(docWord), distance, d.counts[id]),
Score: d.scoreFunc(wordRunes, []rune(docWord), distance, d.counts[id]),
})
}
// the most common mistake is a transposition of letters.
Expand All @@ -135,7 +139,7 @@ func (d *dictionary) getCandidates(word string, max int) []match {
}
result.Push(match{
Value: docWord,
Score: calcScore(wordRunes, []rune(docWord), distance, d.counts[id]),
Score: d.scoreFunc(wordRunes, []rune(docWord), distance, d.counts[id]),
})
}
}
Expand Down Expand Up @@ -180,20 +184,6 @@ func (d *dictionary) computeCandidateBitmaps(bmSrc bitmap.Bitmap32) map[uint64]s
return bitmaps
}

func calcScore(src []rune, candidate []rune, distance int, cnt int) float64 {
mult := math.Log1p(float64(cnt))
// if first letters are the same, increase score
if src[0] == candidate[0] {
mult *= 1.5
// if second letters are the same too, increase score even more
if len(src) > 1 && len(candidate) > 1 && src[1] == candidate[1] {
mult *= 1.5
}
}

return 1 / (1 + float64(distance*distance)) * mult
}

var _ encoding.BinaryMarshaler = (*dictionary)(nil)
var _ encoding.BinaryUnmarshaler = (*dictionary)(nil)

Expand Down Expand Up @@ -240,6 +230,7 @@ func (d *dictionary) UnmarshalBinary(data []byte) error {
d.words = dictData.Words
d.index = dictData.Index
d.maxErrors = dictData.MaxErrors
d.scoreFunc = defaultScorefunc

var max uint32
for _, id := range d.ids {
Expand Down
6 changes: 3 additions & 3 deletions dictionary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func Test_dictionary_id(t *testing.T) {
dict, err := newDictionary(DefaultAlphabet, DefaultMaxErrors)
dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors)
require.NoError(t, err)

t.Run("must return 0 for unexisting word", func(t *testing.T) {
Expand All @@ -24,7 +24,7 @@ func Test_dictionary_id(t *testing.T) {

func Test_dictionary_add(t *testing.T) {
t.Run("must add word to dictionary index", func(t *testing.T) {
dict, err := newDictionary(DefaultAlphabet, DefaultMaxErrors)
dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors)
require.NoError(t, err)

id, err := dict.add("qwe")
Expand All @@ -49,7 +49,7 @@ func Test_dictionary_add(t *testing.T) {

func Test_Dictionary_Inc(t *testing.T) {
t.Run("must increase counter value", func(t *testing.T) {
dict, err := newDictionary(DefaultAlphabet, DefaultMaxErrors)
dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors)
dict.counts[1] = 0
require.NoError(t, err)

Expand Down
12 changes: 8 additions & 4 deletions save_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"path"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -26,7 +25,12 @@ func Test_Spellchecker_Save(t *testing.T) {
m2, err := Load(file)
require.NoError(t, err)

assert.EqualValues(t, m1.dict.id("green"), m2.dict.id("green"))
assert.EqualValues(t, m1.dict.maxErrors, m2.dict.maxErrors)
assert.EqualValues(t, m1.dict.nextID(), m2.dict.nextID())
require.EqualValues(t, m1.dict.id("green"), m2.dict.id("green"))
require.EqualValues(t, m1.dict.maxErrors, m2.dict.maxErrors)
require.EqualValues(t, m1.dict.nextID(), m2.dict.nextID())

matches := m2.dict.find("orange", 1)
require.Len(t, matches, 1)
require.Equal(t, matches[0].Value, "orange")
require.Greater(t, matches[0].Score, 0.0)
}
45 changes: 36 additions & 9 deletions spellchecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,41 @@ import (
"bufio"
"fmt"
"io"
"math"
"sync"
)

const DefaultMaxErrors = 2

// OptionFunc option setter
type OptionFunc func(m *Spellchecker) error
type OptionFunc func(s *Spellchecker) error

type Spellchecker struct {
mtx sync.RWMutex

dict *dictionary
splitter bufio.SplitFunc
scoreFunc scoreFunc
maxErrors int
}

func New(alphabet string, opts ...OptionFunc) (*Spellchecker, error) {
result := &Spellchecker{
maxErrors: DefaultMaxErrors,
scoreFunc: defaultScorefunc,
}
dict, err := newDictionary(alphabet, result.scoreFunc, result.maxErrors)
if err != nil {
return nil, err
}
result.dict = dict

for _, o := range opts {
if err := o(result); err != nil {
return nil, err
}
}

dict, err := newDictionary(alphabet, result.maxErrors)
if err != nil {
return nil, err
}
result.dict = dict

return result, nil
}

Expand Down Expand Up @@ -151,8 +154,32 @@ func WithSplitter(f bufio.SplitFunc) OptionFunc {
// WithMaxErrors set maxErrors, which is a max diff in bits betweeen the "search word" and a "dictionary word".
// i.e. one simple symbol replacement (problam => problem ) is a two-bit difference
func WithMaxErrors(maxErrors int) OptionFunc {
return func(m *Spellchecker) error {
m.maxErrors = maxErrors
return func(s *Spellchecker) error {
s.maxErrors = maxErrors
return nil
}
}

type ScoreFunc = scoreFunc

// WithScoreFunc specify a function that will be used for scoring
func WithScoreFunc(f ScoreFunc) OptionFunc {
return func(s *Spellchecker) error {
s.dict.scoreFunc = f
return nil
}
}

var defaultScorefunc scoreFunc = func(src, candidate []rune, distance, cnt int) float64 {
mult := math.Log1p(float64(cnt))
// if first letters are the same, increase score
if src[0] == candidate[0] {
mult *= 1.5
// if second letters are the same too, increase score even more
if len(src) > 1 && len(candidate) > 1 && src[1] == candidate[1] {
mult *= 1.5
}
}

return 1 / (1 + float64(distance*distance)) * mult
}

0 comments on commit 87dfa2a

Please sign in to comment.