Skip to content

Commit

Permalink
Fine-tune the sample size for CMRR.py (#136)
Browse files Browse the repository at this point in the history
* Fine-tune sample size for CMRR.py

So here's the idea: the amount of time it takes to run CMRR for small values of days_to_simulate is very small. Which means that we can increase the sample size without having to worry about CMRR taking too long to annoy the user.
Instead of writing a whole bunch of else-if statements, I made a function that calculates the "best" sample size. Specifically, it calculates what sample size would make CMRR run roughly as long as it owuld with sample_size=4 and days_to_simulate=365.
sample_size=4 and days_to_simulate=365 is used as the default, and then if days_to_simulate is lower, sample size is increased to make the total run time approximately constant.

* Update fsrs_simulator.py

* Update fsrs_simulator.py

* Change the coefficients.py

* Fixed a problem where the output was 3 instead of 4.py

* Update fsrs_simulator.py

I don't know how I managed to delete "best" twice today

* Update fsrs_simulator.py

* set loss_aversion to 1 when plotting workload

* Update ylim.py

Don't mind me, I'm just trying things

* Update fsrs_simulator.py

* min -> minimum.py

* Decrease ylim.py

* Update fsrs_simulator.py

Ensure that nothing is plotted above the box with the graph

* bump version

---------

Co-authored-by: Jarrett Ye <jarrett.ye@outlook.com>
  • Loading branch information
Expertium and L-M-Sherlock authored Sep 9, 2024
1 parent f2202b3 commit aef1af1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "5.0.7"
version = "5.0.8"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
40 changes: 26 additions & 14 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,19 @@ def sample(
workload_only=False,
):
results = []
if learn_span < 100:
SAMPLE_SIZE = 16
elif learn_span < 365:
SAMPLE_SIZE = 8
else:
SAMPLE_SIZE = 4

def best_sample_size(days_to_simulate):
if days_to_simulate <= 30:
return 45
elif days_to_simulate >= 365:
return 4
else:
a1, a2, a3 = 8.20e-07, 2.41e-03, 1.30e-02
factor = a1 * np.power(days_to_simulate, 2) + a2 * days_to_simulate + a3
default_sample_size = 4
return int(default_sample_size/factor)

SAMPLE_SIZE = best_sample_size(learn_span)

for i in range(SAMPLE_SIZE):
_, _, _, memorized_cnt_per_day, cost_per_day = simulate(
Expand Down Expand Up @@ -422,6 +429,7 @@ def workload_graph(default_params, sampling_size=30):
default_params["deck_size"] / default_params["learn_span"]
)
default_params["review_limit_perday"] = math.inf
default_params["loss_aversion"] = 1
workload = [sample(r=r, workload_only=True, **default_params) for r in R]

# this is for testing
Expand Down Expand Up @@ -513,10 +521,14 @@ def workload_graph(default_params, sampling_size=30):
ax.xaxis.set_tick_params(labelsize=14)
ax.set_xlim(0.7, 0.99)

if max_w >= 4.5 * min_w:
lim = 4.5 * min_w
elif max_w >= 3.5 * min_w:
if max_w >= 3.5 * min_w:
lim = 3.5 * min_w
elif max_w >= 3 * min_w:
lim = 3 * min_w
elif max_w >= 2.5 * min_w:
lim = 2.5 * min_w
elif max_w >= 2 * min_w:
lim = 2 * min_w
else:
lim = 1.1 * max_w

Expand All @@ -527,13 +539,13 @@ def workload_graph(default_params, sampling_size=30):
ax.text(
0.701,
min_w,
"min. workload",
"minimum workload",
ha="left",
va="bottom",
color="black",
fontsize=12,
)
if max_w >= 1.8 * min_w:
if lim >= 1.8 * min_w:
ax.axhline(y=1.5 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand All @@ -544,7 +556,7 @@ def workload_graph(default_params, sampling_size=30):
color="black",
fontsize=12,
)
if max_w >= 2.3 * min_w:
if lim >= 2.3 * min_w:
ax.axhline(y=2 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand All @@ -555,7 +567,7 @@ def workload_graph(default_params, sampling_size=30):
color="black",
fontsize=12,
)
if max_w >= 2.8 * min_w:
if lim >= 2.8 * min_w:
ax.axhline(y=2.5 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand All @@ -566,7 +578,7 @@ def workload_graph(default_params, sampling_size=30):
color="black",
fontsize=12,
)
if max_w >= 3.3 * min_w:
if lim >= 3.3 * min_w:
ax.axhline(y=3 * min_w, color="black", alpha=0.75, ls="--")
ax.text(
0.701,
Expand Down

0 comments on commit aef1af1

Please sign in to comment.