Feat/float delta_t
L-M-Sherlock committed Oct 9, 2024
1 parent 297025c commit a8fb1f2
Showing 2 changed files with 77 additions and 44 deletions.
115 changes: 74 additions & 41 deletions src/fsrs_optimizer/fsrs_optimizer.py
@@ -8,31 +8,31 @@
 from typing import List, Optional
 from datetime import timedelta, datetime
 from collections import defaultdict
-import statsmodels.api as sm
-from statsmodels.nonparametric.smoothers_lowess import lowess
+import statsmodels.api as sm  # type: ignore
+from statsmodels.nonparametric.smoothers_lowess import lowess  # type: ignore
 import matplotlib.pyplot as plt
 import matplotlib.ticker as ticker
 import torch
 from torch import nn
 from torch import Tensor
 from torch.utils.data import Dataset
 from torch.nn.utils.rnn import pad_sequence
-from sklearn.model_selection import TimeSeriesSplit
-from sklearn.metrics import (
+from sklearn.model_selection import TimeSeriesSplit  # type: ignore
+from sklearn.metrics import (  # type: ignore
     root_mean_squared_error,
     mean_absolute_error,
     mean_absolute_percentage_error,
     r2_score,
 )
-from scipy.optimize import minimize
+from scipy.optimize import minimize  # type: ignore
 from itertools import accumulate
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm  # type: ignore
 import warnings
 
 try:
     from .fsrs_simulator import *
-except:
-    from fsrs_simulator import *
+except ImportError:
+    from fsrs_simulator import *  # type: ignore
 
 warnings.filterwarnings("ignore", category=UserWarning)
@@ -67,9 +67,10 @@
 
 
 class FSRS(nn.Module):
-    def __init__(self, w: List[float]):
+    def __init__(self, w: List[float], float_delta_t: bool = False):
         super(FSRS, self).__init__()
         self.w = nn.Parameter(torch.tensor(w, dtype=torch.float32))
+        self.float_delta_t = float_delta_t
 
     def stability_after_success(
         self, state: Tensor, r: Tensor, rating: Tensor
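The `float_delta_t` switch added here threads through every entry point below (Trainer, Collection, Optimizer). A minimal sketch of opting in, assuming the `DEFAULT_PARAMETER` weights that this diff itself passes to `Collection` later on:

```python
# Hedged sketch: DEFAULT_PARAMETER comes from fsrs_simulator (star-imported above);
# any trained FSRS weight vector works in its place.
w = list(DEFAULT_PARAMETER)

model_days = FSRS(w)                       # classic mode: whole-day intervals
model_float = FSRS(w, float_delta_t=True)  # intervals as fractional days
```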
@@ -128,14 +129,22 @@ def step(self, X: Tensor, state: Tensor) -> Tensor:
         r = power_forgetting_curve(X[:, 0], state[:, 0])
         short_term = X[:, 0] < 1
         success = X[:, 1] > 1
-        new_s = torch.where(
-            short_term,
-            self.stability_short_term(state, X[:, 1]),
-            torch.where(
-                success,
-                self.stability_after_success(state, r, X[:, 1]),
-                self.stability_after_failure(state, r),
-            ),
-        )
+        new_s = (
+            torch.where(
+                short_term,
+                self.stability_short_term(state, X[:, 1]),
+                torch.where(
+                    success,
+                    self.stability_after_success(state, r, X[:, 1]),
+                    self.stability_after_failure(state, r),
+                ),
+            )
+            if not self.float_delta_t
+            else torch.where(
+                success,
+                self.stability_after_success(state, r, X[:, 1]),
+                self.stability_after_failure(state, r),
+            )
+        )
         new_d = self.next_d(state, X[:, 1])
         new_d = new_d.clamp(1, 10)
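In the rewritten `step`, integer mode keeps the original routing: same-day reviews (`X[:, 0] < 1`) go through the short-term stability formula. Float mode drops that branch, because a fractional interval like 0.003 days is now a real elapsed time and should take the ordinary success/failure update. A toy sketch of just the selection logic, with scalar stand-ins for the three stability formulas:

```python
import torch

delta_t = torch.tensor([0.003, 2.0])  # ~4 minutes and 2 days
rating = torch.tensor([3.0, 1.0])     # Good, Again
short_term = delta_t < 1
success = rating > 1

def pick_new_s(float_delta_t: bool) -> torch.Tensor:
    s_short, s_succ, s_fail = torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)
    long_term = torch.where(success, s_succ, s_fail)
    return long_term if float_delta_t else torch.where(short_term, s_short, long_term)

print(pick_new_s(False))  # tensor([1., 3.]) — sub-day review hits the short-term formula
print(pick_new_s(True))   # tensor([2., 3.]) — fractional interval treated as long-term
```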
@@ -192,7 +201,7 @@ def lineToTensor(line: str) -> Tensor:
     response = line[1].split(",")
     tensor = torch.zeros(len(response), 2)
     for li, response in enumerate(response):
-        tensor[li][0] = int(ivl[li])
+        tensor[li][0] = float(ivl[li])
         tensor[li][1] = int(response)
     return tensor
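`lineToTensor` switches from `int()` to `float()` when parsing intervals, so sub-day entries in `t_history` survive instead of truncating. A toy round trip with hypothetical history strings:

```python
import torch

ivl = "0,0.25".split(",")   # a 6-hour gap between reviews
ratings = "3,3".split(",")
tensor = torch.zeros(len(ratings), 2)
for i, r in enumerate(ratings):
    tensor[i][0] = float(ivl[i])  # the old int() path assumed whole-day strings like "0" or "3"
    tensor[i][1] = int(r)
print(tensor)  # tensor([[0.0000, 3.0000], [0.2500, 3.0000]])
```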

@@ -277,8 +286,9 @@ def __init__(
         lr: float = 1e-2,
         batch_size: int = 256,
         max_seq_len: int = 64,
+        float_delta_t: bool = False,
     ) -> None:
-        self.model = FSRS(init_w)
+        self.model = FSRS(init_w, float_delta_t)
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
         self.clipper = ParameterClipper()
         self.batch_size = batch_size
@@ -292,6 +302,7 @@ def __init__(
         self.avg_train_losses = []
         self.avg_eval_losses = []
         self.loss_fn = nn.BCELoss(reduction="none")
+        self.float_delta_t = float_delta_t
 
     def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]):
         self.train_set = BatchDataset(
@@ -330,6 +341,9 @@ def train(self, verbose: bool = True):
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 loss = self.loss_fn(retentions, labels).sum()
                 loss.backward()
+                if self.float_delta_t:
+                    for param in self.model.parameters():
+                        param.grad[:4] = torch.zeros(4)
                 self.optimizer.step()
                 self.scheduler.step()
                 self.model.apply(self.clipper)
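When `float_delta_t` is set, the gradient slice for the first four weights — the rating-specific initial stabilities fitted by `pretrain` — is zeroed after `backward()`, so `optimizer.step()` leaves them frozen while the rest keep training. The same trick in miniature:

```python
import torch

w = torch.nn.Parameter(torch.ones(19))
opt = torch.optim.Adam([w], lr=0.1)

loss = (w * w).sum()
loss.backward()
w.grad[:4] = torch.zeros(4)  # freeze w[0..3]: Adam sees a zero gradient there
opt.step()

print(w.data[:4])  # unchanged (all ones); only w[4:] moved
```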
@@ -401,8 +415,8 @@ def plot(self):
 
 
 class Collection:
-    def __init__(self, w: List[float]) -> None:
-        self.model = FSRS(w)
+    def __init__(self, w: List[float], float_delta_t: bool = False) -> None:
+        self.model = FSRS(w, float_delta_t)
         self.model.eval()
 
     def predict(self, t_history: str, r_history: str):
@@ -477,8 +491,13 @@ def loss(stability):
 
 
 class Optimizer:
-    def __init__(self) -> None:
+    float_delta_t: bool = False
+
+    def __init__(self, float_delta_t: bool = False) -> None:
         tqdm.pandas()
+        self.float_delta_t = float_delta_t
+        global S_MIN
+        S_MIN = 1e-6 if float_delta_t else 0.01
 
     def anki_extract(
         self,
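`S_MIN` is rebound globally on construction: with fractional days, a memory a few minutes old can legitimately have a stability far below one day, so the floor drops from 0.01 to 1e-6. A sketch of why the old floor distorts short gaps, assuming the standard FSRS power forgetting curve (`DECAY = -0.5`, as in fsrs_simulator.py below):

```python
DECAY = -0.5
FACTOR = 0.9 ** (1 / DECAY) - 1  # ~0.2346, chosen so r(t=s) = 0.9

def power_forgetting_curve(t: float, s: float) -> float:
    return (1 + FACTOR * t / s) ** DECAY

t = 0.003  # ~4 minutes, in days
print(power_forgetting_curve(t, max(t, 0.01)))  # ~0.967: a floor of 0.01 inflates recall
print(power_forgetting_curve(t, max(t, 1e-6)))  # 0.900: true sub-day stability kept
```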
@@ -723,7 +742,10 @@ def create_time_series(
             )
         ).to_julian_date()
         # df.drop_duplicates(["card_id", "real_days"], keep="first", inplace=True)
-        df["delta_t"] = df.real_days.diff()
+        if self.float_delta_t:
+            df["delta_t"] = df["review_time"].diff().fillna(0) / 1000 / 86400
+        else:
+            df["delta_t"] = df.real_days.diff()
         df.fillna({"delta_t": 0}, inplace=True)
         df["i"] = df.groupby("card_id").cumcount() + 1
         df.loc[df["i"] == 1, "delta_t"] = -1
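The float path computes `delta_t` straight from the millisecond `review_time` column instead of calendar-day differences: divide by 1000 for seconds, then by 86400 for days. A quick check with hypothetical timestamps (the surrounding code then resets each card's first row to -1, as above):

```python
import pandas as pd

# Epoch milliseconds: two reviews 5 minutes apart, then one ~1 day later.
review_time = pd.Series([1_700_000_000_000, 1_700_000_300_000, 1_700_086_400_000])
delta_t = review_time.diff().fillna(0) / 1000 / 86400
print(delta_t.round(6).tolist())  # [0.0, 0.003472, 0.996528]
```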
@@ -741,7 +763,9 @@ def cum_concat(x):
             return list(accumulate(x))
 
         t_history_list = df.groupby("card_id", group_keys=False)["delta_t"].apply(
-            lambda x: cum_concat([[int(max(0, i))] for i in x])
+            lambda x: cum_concat(
+                [[max(0, round(i, 6) if self.float_delta_t else int(i))] for i in x]
+            )
         )
         df["t_history"] = [
             ",".join(map(str, item[:-1]))
@@ -783,20 +807,21 @@ def cum_concat(x):
         df["first_rating"] = df["r_history"].map(lambda x: x[0] if len(x) > 0 else "")
         df["y"] = df["review_rating"].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x])
 
-        df[df["i"] == 2] = (
-            df[df["i"] == 2]
-            .groupby(by=["first_rating"], as_index=False, group_keys=False)
-            .apply(remove_outliers)
-        )
-        df.dropna(inplace=True)
+        if not self.float_delta_t:
+            df[df["i"] == 2] = (
+                df[df["i"] == 2]
+                .groupby(by=["first_rating"], as_index=False, group_keys=False)
+                .apply(remove_outliers)
+            )
+            df.dropna(inplace=True)
 
-        df = df.groupby("card_id", as_index=False, group_keys=False).progress_apply(
-            remove_non_continuous_rows
-        )
+            df = df.groupby("card_id", as_index=False, group_keys=False).progress_apply(
+                remove_non_continuous_rows
+            )
 
         df["review_time"] = df["review_time"].astype(int)
         df["review_rating"] = df["review_rating"].astype(int)
-        df["delta_t"] = df["delta_t"].astype(int)
+        df["delta_t"] = df["delta_t"].astype(float if self.float_delta_t else int)
         df["i"] = df["i"].astype(int)
         df["t_history"] = df["t_history"].astype(str)
         df["r_history"] = df["r_history"].astype(str)
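Two consequences of float mode land in this hunk: the second-review outlier trimming and the continuity filter, which assume whole-day `delta_t` buckets, now run only in integer mode, and `delta_t` keeps its float dtype instead of being coerced. The dtype difference in isolation:

```python
import pandas as pd

delta_t = pd.Series([0.0, 0.003472, 1.25])
print(delta_t.astype(int).tolist())    # [0, 0, 1] — sub-day detail truncated away
print(delta_t.astype(float).tolist())  # [0.0, 0.003472, 1.25]
```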
@@ -965,8 +990,11 @@ def pretrain(self, dataset=None, verbose=True):
                 )
                 continue
             delta_t = group["delta_t"]
-            recall = (group["y"]["mean"] * group["y"]["count"] + average_recall * 1) / (
-                group["y"]["count"] + 1
+            recall = (
+                (group["y"]["mean"] * group["y"]["count"] + average_recall * 1)
+                / (group["y"]["count"] + 1)
+                if not self.float_delta_t
+                else group["y"]["mean"]
             )
             count = group["y"]["count"]

@@ -978,7 +1006,7 @@ def loss(stability):
                     -(recall * np.log(y_pred) + (1 - recall) * np.log(1 - y_pred))
                     * count
                 )
-                l1 = np.abs(stability - init_s0) / 16
+                l1 = np.abs(stability - init_s0) / 16 if not self.float_delta_t else 0
                 return logloss + l1
 
             res = minimize(
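So under float training, `pretrain` fits each first-rating group against the raw empirical recall — no Laplace-style shrinkage toward `average_recall` — and drops the L1 pull toward the default initial stability. A standalone sketch of the objective (names mirror the diff; `power_forgetting_curve` as defined earlier):

```python
import numpy as np

DECAY = -0.5
FACTOR = 0.9 ** (1 / DECAY) - 1

def power_forgetting_curve(t, s):
    return (1 + FACTOR * t / s) ** DECAY

def pretrain_loss(stability, delta_t, recall, count, init_s0, float_delta_t):
    y_pred = power_forgetting_curve(delta_t, stability)
    logloss = np.sum(
        -(recall * np.log(y_pred) + (1 - recall) * np.log(1 - y_pred)) * count
    )
    l1 = 0 if float_delta_t else np.abs(stability - init_s0) / 16
    return logloss + l1
```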
@@ -1152,6 +1180,7 @@ def train(
                 n_epoch=n_epoch,
                 lr=lr,
                 batch_size=batch_size,
+                float_delta_t=self.float_delta_t,
             )
             w.append(trainer.train(verbose=verbose))
             self.w = w[-1]
@@ -1171,6 +1200,7 @@ def train(
                 n_epoch=n_epoch,
                 lr=lr,
                 batch_size=batch_size,
+                float_delta_t=self.float_delta_t,
             )
             w.append(trainer.train(verbose=verbose))
             if verbose:
@@ -1185,7 +1215,7 @@ def train(
         return plots
 
     def preview(self, requestRetention: float, verbose=False):
-        my_collection = Collection(self.w)
+        my_collection = Collection(self.w, self.float_delta_t)
         preview_text = "1:again, 2:hard, 3:good, 4:easy\n"
         n_learning_steps = 3
         for first_rating in (1, 2, 3, 4):
@@ -1258,7 +1288,7 @@ def preview(self, requestRetention: float, verbose=False):
         return preview_text
 
     def preview_sequence(self, test_rating_sequence: str, requestRetention: float):
-        my_collection = Collection(self.w)
+        my_collection = Collection(self.w, self.float_delta_t)
 
         t_history = "0"
         d_history = "0"
@@ -1304,7 +1334,7 @@ def preview_sequence(self, test_rating_sequence: str, requestRetention: float):
         return preview_text
 
     def predict_memory_states(self):
-        my_collection = Collection(self.w)
+        my_collection = Collection(self.w, self.float_delta_t)
 
         stabilities, difficulties = my_collection.batch_predict(self.dataset)
         stabilities = map(lambda x: round(x, 2), stabilities)
@@ -1436,7 +1466,7 @@ def moving_average(data, window_size=365 // 20):
         return (fig1, fig2, fig3, fig4, fig5, fig6)
 
     def evaluate(self, save_to_file=True):
-        my_collection = Collection(DEFAULT_PARAMETER)
+        my_collection = Collection(DEFAULT_PARAMETER, self.float_delta_t)
         stabilities, difficulties = my_collection.batch_predict(self.dataset)
         self.dataset["stability"] = stabilities
         self.dataset["difficulty"] = difficulties
@@ -1449,7 +1479,7 @@ def evaluate(self, save_to_file=True):
         )
         loss_before = self.dataset["log_loss"].mean()
 
-        my_collection = Collection(self.w)
+        my_collection = Collection(self.w, self.float_delta_t)
         stabilities, difficulties = my_collection.batch_predict(self.dataset)
         self.dataset["stability"] = stabilities
         self.dataset["difficulty"] = difficulties
@@ -1567,18 +1597,21 @@ def to_percent(temp, position):
         ax2.legend(lns, labs, loc="lower right")
         ax2.grid(linestyle="--")
         ax2.yaxis.set_major_formatter(ticker.FuncFormatter(to_percent))
+        ax2.xaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))
         return fig
 
     def formula_analysis(self):
         analysis_df = self.dataset[self.dataset["i"] > 2].copy()
         analysis_df["tensor"] = analysis_df["tensor"].map(lambda x: x[:-1])
-        my_collection = Collection(self.w)
+        my_collection = Collection(self.w, self.float_delta_t)
         stabilities, difficulties = my_collection.batch_predict(analysis_df)
         analysis_df["last_s"] = stabilities
         analysis_df["last_d"] = difficulties
         analysis_df["last_delta_t"] = analysis_df["t_history"].map(
-            lambda x: int(x.split(",")[-1])
+            lambda x: (
+                int(x.split(",")[-1])
+                if not self.float_delta_t
+                else float(x.split(",")[-1])
+            )
         )
         analysis_df["last_r"] = power_forgetting_curve(
             analysis_df["delta_t"], analysis_df["last_s"]
6 changes: 3 additions & 3 deletions src/fsrs_optimizer/fsrs_simulator.py
@@ -1,7 +1,7 @@
 import math
 import numpy as np
 from matplotlib import pyplot as plt
-from tqdm import trange
+from tqdm import trange  # type: ignore
 
 
 DECAY = -0.5
@@ -273,8 +273,8 @@ def best_sample_size(days_to_simulate):
         a1, a2, a3 = 8.20e-07, 2.41e-03, 1.30e-02
         factor = a1 * np.power(days_to_simulate, 2) + a2 * days_to_simulate + a3
         default_sample_size = 4
-        return int(default_sample_size/factor)
+        return int(default_sample_size / factor)
 
     SAMPLE_SIZE = best_sample_size(learn_span)
 
     for i in range(SAMPLE_SIZE):
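`best_sample_size` trades simulation length against Monte Carlo repetitions: the quadratic `factor` grows with the horizon, so longer simulations get fewer samples for roughly constant total work. A standalone spot check of the reformatted function:

```python
import numpy as np

def best_sample_size(days_to_simulate):
    a1, a2, a3 = 8.20e-07, 2.41e-03, 1.30e-02
    factor = a1 * np.power(days_to_simulate, 2) + a2 * days_to_simulate + a3
    return int(4 / factor)

print(best_sample_size(30))   # 46 — short horizon, many samples
print(best_sample_size(365))  # 3  — year-long horizon, few samples
```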
