diff --git a/fsrs.go b/fsrs.go index ec4bf2a..c9ea253 100644 --- a/fsrs.go +++ b/fsrs.go @@ -1,133 +1,177 @@ package fsrs import ( - "fmt" "math" "time" - - "github.com/ImSingee/go-ex/set" ) -func (globalData *GlobalData) Learn(cardData *CardData, grade Grade) { - if grade < GradeNewCard || grade > GradeEasy { - panic(fmt.Sprintf("Invalid grade: %d", grade)) +func (p *Parameters) Repeat(card Card, now time.Time) map[Rating]SchedulingInfo { + if card.State == New { + card.ElapsedDays = 0 + } else { + card.ElapsedDays = uint64(math.Round(float64(now.Sub(card.LastReview) / time.Hour / 24))) } - - now := time.Now() - - if grade == GradeNewCard { // learn new card - addDay := math.Round(globalData.DefaultStability * math.Log(globalData.RequestRetention) / math.Log(0.9)) - - cardData.Due = now.Add(time.Duration(addDay * float64(24*time.Hour))) - cardData.Interval = 0 - cardData.Difficulty = globalData.DefaultDifficulty - cardData.Stability = globalData.DefaultStability - cardData.Retrievability = 1 - cardData.LastGrade = GradeNewCard - cardData.Review = now - cardData.Reps = 1 - cardData.Lapses = 0 - - return + card.LastReview = now + card.Reps += 1 + s := new(SchedulingCards) + s.init(card) + s.updateState(card.State) + + switch card.State { + case New: + p.initDS(s) + + s.Again.Due = now.Add(1 * time.Minute) + s.Hard.Due = now.Add(5 * time.Minute) + s.Good.Due = now.Add(10 * time.Minute) + easyInterval := p.nextInterval(s.Easy.Stability * p.EasyBonus) + s.Easy.ScheduledDays = uint64(easyInterval) + s.Easy.Due = now.Add(time.Duration(easyInterval) * 24 * time.Hour) + case Learning, Relearning: + hardInterval := p.nextInterval(s.Hard.Stability) + goodInterval := math.Max(p.nextInterval(s.Good.Stability), hardInterval+1) + easyInterval := math.Max(p.nextInterval(s.Easy.Stability*p.EasyBonus), goodInterval+1) + + s.schedule(now, hardInterval, goodInterval, easyInterval) + case Review: + interval := float64(card.ElapsedDays) + lastD := card.Difficulty + lastS := card.Stability + retrievability := math.Exp(math.Log(0.9) * interval / lastS) + p.nextDS(s, lastD, lastS, retrievability) + + hardInterval := p.nextInterval(lastS * p.HardFactor) + goodInterval := p.nextInterval(s.Good.Stability) + hardInterval = math.Min(hardInterval, goodInterval) + goodInterval = math.Max(goodInterval, hardInterval+1) + easyInterval := math.Max(p.nextInterval(s.Easy.Stability*p.EasyBonus), goodInterval+1) + s.schedule(now, hardInterval, goodInterval, easyInterval) } + return s.recordLog(card, now) +} - // review card after learn - lastDifficulty := cardData.Difficulty - lastStability := cardData.Stability - lastLapses := cardData.Lapses - lastReps := cardData.Reps - lastReview := cardData.Review - - h := cardData.CardDataItem - cardData.History = append(cardData.History, &h) - - diffDay := (time.Since(lastReview) / time.Hour / 24) + 1 - if diffDay > 0 { - cardData.Interval = uint64(diffDay) - } else { - cardData.Interval = 0 +func (s *SchedulingCards) updateState(state State) { + switch state { + case New: + s.Again.State = Learning + s.Hard.State = Learning + s.Good.State = Learning + s.Easy.State = Review + s.Again.Lapses += 1 + case Learning, Relearning: + s.Again.State = state + s.Hard.State = Review + s.Good.State = Review + s.Easy.State = Review + case Review: + s.Again.State = Relearning + s.Hard.State = Review + s.Good.State = Review + s.Easy.State = Review + s.Again.Lapses += 1 } +} - cardData.Review = now - cardData.Retrievability = math.Exp(math.Log(0.9) * float64(cardData.Interval) / lastStability) - cardData.Difficulty = math.Min(math.Max(lastDifficulty+cardData.Retrievability-float64(grade)+0.2, 1), 10) - - if grade == GradeForgetting { - cardData.Stability = globalData.DefaultStability * math.Exp(-0.3*float64(lastLapses+1)) +func (s *SchedulingCards) schedule(now time.Time, hardInterval float64, goodInterval float64, easyInterval float64) { + s.Again.ScheduledDays = 0 + s.Hard.ScheduledDays = uint64(hardInterval) + s.Good.ScheduledDays = uint64(goodInterval) + s.Easy.ScheduledDays = uint64(easyInterval) + s.Again.Due = now.Add(5 * time.Minute) + s.Hard.Due = now.Add(time.Duration(hardInterval) * 24 * time.Hour) + s.Good.Due = now.Add(time.Duration(goodInterval) * 24 * time.Hour) + s.Easy.Due = now.Add(time.Duration(easyInterval) * 24 * time.Hour) +} - if lastReps > 1 { - globalData.TotalDiff = globalData.TotalDiff - cardData.Retrievability - } +func (s *SchedulingCards) recordLog(card Card, now time.Time) map[Rating]SchedulingInfo { + m := map[Rating]SchedulingInfo{ + Again: {s.Again, ReviewLog{ + Rating: Again, + ScheduledDays: s.Again.ScheduledDays, + ElapsedDays: card.ElapsedDays, + Review: now, + State: card.State, + }}, + Hard: {s.Hard, ReviewLog{ + Rating: Hard, + ScheduledDays: s.Hard.ScheduledDays, + ElapsedDays: card.ElapsedDays, + Review: now, + State: card.State, + }}, + Good: {s.Good, ReviewLog{ + Rating: Good, + ScheduledDays: s.Good.ScheduledDays, + ElapsedDays: card.ElapsedDays, + Review: now, + State: card.State, + }}, + Easy: {s.Easy, ReviewLog{ + Rating: Easy, + ScheduledDays: s.Easy.ScheduledDays, + ElapsedDays: card.ElapsedDays, + Review: now, + State: card.State, + }}, + } + return m +} - cardData.Lapses = lastLapses + 1 - cardData.Reps = 1 +func (p *Parameters) initDS(s *SchedulingCards) { + s.Again.Difficulty = p.initDifficulty(Again) + s.Again.Stability = p.initStability(Again) + s.Hard.Difficulty = p.initDifficulty(Hard) + s.Hard.Stability = p.initStability(Hard) + s.Good.Difficulty = p.initDifficulty(Good) + s.Good.Stability = p.initStability(Good) + s.Easy.Difficulty = p.initDifficulty(Easy) + s.Easy.Stability = p.initStability(Easy) +} - } else { //grade == 1 || grade == 2 - cardData.Stability = lastStability * (1 + globalData.IncreaseFactor*math.Pow(cardData.Difficulty, globalData.DifficultyDecay)*math.Pow(lastStability, globalData.StabilityDecay)*(math.Exp(1-cardData.Retrievability)-1)) +func (p *Parameters) nextDS(s *SchedulingCards, lastD float64, lastS float64, retrievability float64) { + s.Again.Difficulty = p.nextDifficulty(lastD, Again) + s.Again.Stability = p.nextForgetStability(s.Again.Difficulty, lastS, retrievability) + s.Hard.Difficulty = p.nextDifficulty(lastD, Hard) + s.Hard.Stability = p.nextRecallStability(s.Hard.Difficulty, lastS, retrievability) + s.Good.Difficulty = p.nextDifficulty(lastD, Good) + s.Good.Stability = p.nextRecallStability(s.Good.Difficulty, lastS, retrievability) + s.Easy.Difficulty = p.nextDifficulty(lastD, Easy) + s.Easy.Stability = p.nextRecallStability(s.Easy.Difficulty, lastS, retrievability) +} - if lastReps > 1 { - globalData.TotalDiff = globalData.TotalDiff + 1 - cardData.Retrievability - } +func (p *Parameters) initStability(r Rating) float64 { + return math.Max(p.W[0]+p.W[1]*float64(r), 0.1) +} +func (p *Parameters) initDifficulty(r Rating) float64 { + return constrainDifficulty(p.W[2] + p.W[3]*float64(r-2)) +} - cardData.Lapses = lastLapses - cardData.Reps = lastReps + 1 - } +func constrainDifficulty(d float64) float64 { + return math.Min(math.Max(d, 1), 10) +} - globalData.TotalCase++ - globalData.TotalReview++ +func (p *Parameters) nextInterval(s float64) float64 { + newInterval := s * math.Log(p.RequestRetention) / math.Log(0.9) + return math.Max(math.Min(math.Round(newInterval), p.MaximumInterval), 1) +} - addDay := math.Round(cardData.Stability * math.Log(globalData.RequestRetention) / math.Log(0.9)) - cardData.Due = now.Add(time.Duration(addDay * float64(24*time.Hour))) +func (p *Parameters) nextDifficulty(d float64, r Rating) float64 { + nextD := d + p.W[4]*float64(r-2) + return constrainDifficulty(p.meanReversion(p.W[2], nextD)) +} - // Adaptive globalData.defaultDifficulty - if globalData.TotalCase > 100 { - globalData.DefaultDifficulty = 1.0/math.Pow(float64(globalData.TotalReview), 0.3)*math.Pow(math.Log(globalData.RequestRetention)/math.Max(math.Log(globalData.RequestRetention+globalData.TotalDiff/float64(globalData.TotalCase)), 0), 1/globalData.DifficultyDecay)*5 + (1-1/math.Pow(float64(globalData.TotalReview), 0.3))*globalData.DefaultDifficulty +func (p *Parameters) meanReversion(init float64, current float64) float64 { + return p.W[5]*init + (1-p.W[5])*current +} - globalData.TotalDiff = 0 - globalData.TotalCase = 0 - } +func (p *Parameters) nextRecallStability(d float64, s float64, r float64) float64 { + return s * (1 + math.Exp(p.W[6])* + (11-d)* + math.Pow(s, p.W[7])* + (math.Exp((1-r)*p.W[8])-1)) +} - // Adaptive globalData.defaultStability - if lastReps == 1 && lastLapses == 0 { - retrievability := uint64(0) - if grade > GradeForgetting { - retrievability = 1 - } - globalData.StabilityDataArray = append(globalData.StabilityDataArray, &StabilityData{ - Interval: cardData.Interval, - Retrievability: retrievability, - }) - - if len(globalData.StabilityDataArray) > 0 && len(globalData.StabilityDataArray)%50 == 0 { - intervalSetArray := set.New[uint64]() - - sumRI2S := float64(0) - sumI2S := float64(0) - - for s := 0; s < len(globalData.StabilityDataArray); s++ { - ivl := globalData.StabilityDataArray[s].Interval - - if !intervalSetArray.Has(ivl) { - intervalSetArray.Add(ivl) - - retrievabilitySum := uint64(0) - currentCount := 0 - for _, fi := range globalData.StabilityDataArray { - if fi.Interval == ivl { - retrievabilitySum += fi.Retrievability - currentCount++ - } - } - - if retrievabilitySum > 0 { - sumRI2S = sumRI2S + float64(ivl)*math.Log(float64(retrievabilitySum)/float64(currentCount))*float64(currentCount) - sumI2S = sumI2S + float64(ivl*ivl)*float64(currentCount) - } - } - - } - - globalData.DefaultStability = (math.Max(math.Log(0.9)/(sumRI2S/sumI2S), 0.1) + globalData.DefaultStability) / 2 - } - } +func (p *Parameters) nextForgetStability(d float64, s float64, r float64) float64 { + return p.W[9] * math.Pow(d, p.W[10]) * math.Pow( + s, p.W[11]) * math.Exp((1-r)*p.W[12]) } diff --git a/fsrs_test.go b/fsrs_test.go new file mode 100644 index 0000000..57597c9 --- /dev/null +++ b/fsrs_test.go @@ -0,0 +1,40 @@ +package fsrs + +import ( + "encoding/json" + "testing" + "time" +) + +func TestRepeat(t *testing.T) { + p := DefaultParam() + card := NewCard() + now := time.Date(2022, 11, 29, 12, 30, 0, 0, time.UTC) + schedulingCards := p.Repeat(card, now) + schedule, _ := json.Marshal(schedulingCards) + t.Logf(string(schedule)) + + card = schedulingCards[Good].Card + now = card.Due + schedulingCards = p.Repeat(card, now) + schedule, _ = json.Marshal(schedulingCards) + t.Logf(string(schedule)) + + card = schedulingCards[Good].Card + now = card.Due + schedulingCards = p.Repeat(card, now) + schedule, _ = json.Marshal(schedulingCards) + t.Logf(string(schedule)) + + card = schedulingCards[Again].Card + now = card.Due + schedulingCards = p.Repeat(card, now) + schedule, _ = json.Marshal(schedulingCards) + t.Logf(string(schedule)) + + card = schedulingCards[Good].Card + now = card.Due + schedulingCards = p.Repeat(card, now) + schedule, _ = json.Marshal(schedulingCards) + t.Logf(string(schedule)) +} diff --git a/go.mod b/go.mod index aa2b3b6..39a9ed2 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module github.com/open-spaced-repetition/go-fsrs -go 1.18 - -require github.com/ImSingee/go-ex v0.4.10 +go 1.19 diff --git a/go.sum b/go.sum deleted file mode 100644 index cfd93f1..0000000 --- a/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/ImSingee/go-ex v0.4.10 h1:ZJA6J8NbU7Bnwu8uTAEMBxZ7n7SzCPCztzJHi+ZVCqs= -github.com/ImSingee/go-ex v0.4.10/go.mod h1:CNc3Fqk9GkQfm/1x53vGnQ0BEHH+52siqBJFkVwdy8U= diff --git a/models.go b/models.go index be6b570..98c5378 100644 --- a/models.go +++ b/models.go @@ -1,69 +1,114 @@ package fsrs -import "time" - -type GlobalData struct { - DifficultyDecay float64 `json:"difficultyDecay"` - StabilityDecay float64 `json:"stabilityDecay"` - IncreaseFactor float64 `json:"increaseFactor"` - RequestRetention float64 `json:"requestRetention"` - TotalCase uint64 `json:"totalCase"` - TotalDiff float64 `json:"totalDiff"` - TotalReview uint64 `json:"totalReview"` - DefaultDifficulty float64 `json:"defaultDifficulty"` - DefaultStability float64 `json:"defaultStability"` - StabilityDataArray []*StabilityData `json:"stabilityDataArray"` +import ( + "time" +) + +type weights [13]float64 + +func defaultWeights() weights { + return weights{1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1} } -type StabilityData struct { - Interval uint64 - Retrievability uint64 +type Parameters struct { + RequestRetention float64 + MaximumInterval float64 + EasyBonus float64 + HardFactor float64 + W weights } -type CardData struct { - CardDataItem +func DefaultParam() Parameters { + return Parameters{ + RequestRetention: 0.9, + MaximumInterval: 36500, + EasyBonus: 1.3, + HardFactor: 1.2, + W: defaultWeights(), + } +} - History []*CardDataItem +type Card struct { + Due time.Time `json:"Due"` + Stability float64 `json:"Stability"` + Difficulty float64 `json:"Difficulty"` + ElapsedDays uint64 `json:"ElapsedDays"` + ScheduledDays uint64 `json:"ScheduledDays"` + Reps uint64 `json:"Reps"` + Lapses uint64 `json:"Lapses"` + State State `json:"State"` + LastReview time.Time `json:"LastReview"` } -type CardDataItem struct { - Due time.Time - Interval uint64 // 上次复习间隔(单位为天) - Difficulty float64 - Stability float64 - Retrievability float64 - LastGrade Grade // 上次得分 - Review time.Time - Reps uint64 - Lapses uint64 +func NewCard() Card { + return Card{ + Due: time.Time{}, + Stability: 0, + Difficulty: 0, + ElapsedDays: 0, + ScheduledDays: 0, + Reps: 0, + Lapses: 0, + State: New, + LastReview: time.Time{}, + } } -func (item *CardDataItem) Copy() CardDataItem { - return *item +type ReviewLog struct { + Rating Rating `json:"Rating"` + ScheduledDays uint64 `json:"ScheduledDays"` + ElapsedDays uint64 `json:"ElapsedDays"` + Review time.Time `json:"Review"` + State State `json:"State"` } -type Grade int8 +type SchedulingCards struct { + Again Card + Hard Card + Good Card + Easy Card +} -const ( - GradeForgetting Grade = iota - GradeRemembered - GradeEasy +func (s *SchedulingCards) init(card Card) { + s.Again = card + s.Hard = card + s.Good = card + s.Easy = card +} - GradeNewCard Grade = -1 +type SchedulingInfo struct { + Card Card + ReviewLog ReviewLog +} + +type Rating int8 + +const ( + Again Rating = iota + Hard + Good + Easy ) -// DefaultGlobalData returns the default values of GlobalData -func DefaultGlobalData() GlobalData { - return GlobalData{ - DifficultyDecay: -0.7, - StabilityDecay: 0.2, - IncreaseFactor: 60, - RequestRetention: 0.9, - TotalCase: 0, - TotalDiff: 0, - TotalReview: 0, - DefaultDifficulty: 5, - DefaultStability: 2, - StabilityDataArray: nil, +func (s Rating) String() string { + switch s { + case Again: + return "Again" + case Hard: + return "Hard" + case Good: + return "Good" + case Easy: + return "Easy" } + return "unknown" } + +type State int8 + +const ( + New State = iota + Learning + Review + Relearning +)