Fix/exaggerated initial stability after optimizing (#6)
* replace curve_fit with minimize

* update version

* add L1 regularization for init s loss

* replace mse with rmse

* linear decay for l1

* maxiter = int(np.sqrt(total_count))
L-M-Sherlock authored Jul 30, 2023
1 parent 9dc5377 commit c7c1507
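
Editor's note: in plain terms, the new pretrain objective fitted per first rating (written with the diff's own names) is

    loss(s) = sqrt(sum(count * (recall - power_forgetting_curve(delta_t, s)) ** 2) / total_count)
              + abs(s - init_s0) / (sqrt(s0_size) * total_count)

The second term is the L1 pull toward the per-rating prior init_s0; dividing by total_count makes it decay as reviews accumulate, so the prior dominates only when data is scarce. That is what fixes the exaggerated initial stability named in the commit title.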
Showing 3 changed files with 25 additions and 9 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "FSRS-Optimizer"
-version = "4.5.3"
+version = "4.5.4"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
src/fsrs_optimizer/__main__.py (5 changes: 4 additions & 1 deletion)
@@ -115,7 +115,9 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
     with open(args.out, "a+") as f:
         f.write(profile)
 
-    optimizer.evaluate()
+    loss_before, loss_after = optimizer.evaluate()
+    print(f"Loss before training: {loss_before:.4f}")
+    print(f"Loss after training: {loss_after:.4f}")
     if save_graphs:
         for i, f in enumerate(optimizer.calibration_graph()):
             f.savefig(f"calibration_{i}.png")
@@ -143,6 +145,7 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
         files = [os.path.join(filename, f) for f in files]
     for file_path in files:
         try:
+            print(f"Processing {file_path}")
             process(file_path)
         except Exception as e:
             print(e)
src/fsrs_optimizer/fsrs_optimizer.py (27 changes: 20 additions & 7 deletions)
@@ -17,7 +17,7 @@
 from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
 from sklearn.model_selection import StratifiedGroupKFold
 from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
-from scipy.optimize import curve_fit
+from scipy.optimize import curve_fit, minimize
 from itertools import accumulate
 from tqdm.auto import tqdm
 import warnings
@@ -497,20 +497,34 @@ def pretrain(self, verbose=True):
         rating_count = {}
         average_recall = self.dataset['y'].mean()
         plots = []
+        s0_size = self.S0_dataset_group.shape[0]
+        rating_s0 = {
+            "1": 0.4,
+            "2": 0.6,
+            "3": 2.4,
+            "4": 5.8
+        }
 
         for first_rating in ("1", "2", "3", "4"):
             group = self.S0_dataset_group[self.S0_dataset_group['r_history'] == first_rating]
             if group.empty:
-                tqdm.write(f'Not enough data for first rating {first_rating}. Expected at least 100, got 0.')
+                tqdm.write(f'Not enough data for first rating {first_rating}. Expected at least 1, got 0.')
                 continue
             delta_t = group['delta_t']
             recall = (group['y']['mean'] * group['y']['count'] + average_recall * 1) / (group['y']['count'] + 1)
             count = group['y']['count']
             total_count = sum(count)
-            if total_count < 100:
-                tqdm.write(f'Not enough data for first rating {first_rating}. Expected at least 100, got {total_count}.')
-                continue
-            params, _ = curve_fit(power_forgetting_curve, delta_t, recall, sigma=1/np.sqrt(count), bounds=((0.1), (30 if total_count < 1000 else 365)))
+
+            init_s0 = rating_s0[first_rating]
+
+            def loss(stability):
+                y_pred = power_forgetting_curve(delta_t, stability)
+                rmse = np.sqrt(np.sum((recall - y_pred) ** 2 * count) / total_count)
+                l1 = np.abs(stability - init_s0) / np.sqrt(s0_size) / total_count
+                return rmse + l1
+
+            res = minimize(loss, x0=init_s0, bounds=((0.1, 365),), options={"maxiter": int(np.sqrt(total_count))})
+            params = res.x
             stability = params[0]
             rating_stability[int(first_rating)] = stability
             rating_count[int(first_rating)] = total_count
@@ -527,7 +541,6 @@ def pretrain(self, verbose=True):
             ax.legend(loc='upper right', fancybox=True, shadow=False)
             ax.grid(True)
             ax.set_ylim(0, 1)
-            ax.set_xlim(0, 30)
             ax.set_xlabel('Interval')
             ax.set_ylabel('Recall')
             ax.set_title(f'Forgetting curve for first rating {first_rating} (n={total_count}, s={stability:.2f})')
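
Editor's note: for anyone who wants to poke at the new fit in isolation, here is a minimal, self-contained sketch. It assumes the FSRS v4 power forgetting curve r(t, s) = (1 + t / (9 * s)) ** -1 and invents a tiny aggregated dataset; only the loss function and the minimize call mirror the diff above, everything else is illustrative.

import numpy as np
from scipy.optimize import minimize

def power_forgetting_curve(t, s):
    # assumed FSRS v4-style curve: probability of recall after t days at stability s
    return (1 + t / (9 * s)) ** -1

# Hypothetical aggregated first-review data: intervals (days), mean recall,
# and review counts per interval. Illustrative numbers, not real data.
delta_t = np.array([1, 3, 7, 14])
recall = np.array([0.95, 0.90, 0.80, 0.70])
count = np.array([50, 40, 30, 20])
total_count = count.sum()

init_s0 = 2.4  # prior initial stability for first rating "3" (from rating_s0)
s0_size = 4    # number of first-rating groups (rows of S0_dataset_group)

def loss(stability):
    y_pred = power_forgetting_curve(delta_t, stability)
    # count-weighted RMSE of the fitted curve against observed recall
    rmse = np.sqrt(np.sum((recall - y_pred) ** 2 * count) / total_count)
    # L1 pull toward the prior; the 1/total_count factor makes it fade
    # as reviews accumulate ("linear decay for l1")
    l1 = np.abs(stability - init_s0) / np.sqrt(s0_size) / total_count
    return (rmse + l1).item()

res = minimize(loss, x0=init_s0, bounds=((0.1, 365),),
               options={"maxiter": int(np.sqrt(total_count))})
print(f"fitted initial stability: {res.x[0]:.2f}")

With plenty of reviews the RMSE term dominates and the fit tracks the data; with only a handful, the L1 term keeps the estimate near the prior instead of letting a few lucky early reviews inflate it.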
