From c7568aeb95cf1e0f942fdea08acca87bb46b982e Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sun, 8 Sep 2024 22:14:50 +0800 Subject: [PATCH] Fix/just workload in workload graph (#135) * Fix/just workload in workload graph * Update fsrs_simulator.py * Update fsrs_simulator.py * bump version --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_simulator.py | 81 ++++++++++++++-------------- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d56f263..3833775 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.0.6" +version = "5.0.7" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_simulator.py b/src/fsrs_optimizer/fsrs_simulator.py index b0b754d..4533340 100644 --- a/src/fsrs_optimizer/fsrs_simulator.py +++ b/src/fsrs_optimizer/fsrs_simulator.py @@ -260,8 +260,9 @@ def sample( forget_rating_offset=DEFAULT_FORGET_RATING_OFFSET, forget_session_len=DEFAULT_FORGET_SESSION_LEN, loss_aversion=2.5, + workload_only=False, ): - memorization = [] + results = [] if learn_span < 100: SAMPLE_SIZE = 16 elif learn_span < 365: @@ -290,8 +291,11 @@ def sample( loss_aversion, seed=42 + i, ) - memorization.append(cost_per_day.sum() / memorized_cnt_per_day[-1]) - return np.mean(memorization) + if workload_only: + results.append(cost_per_day.sum()) + else: + results.append(cost_per_day.sum() / memorized_cnt_per_day[-1]) + return np.mean(results) def brent(tol=0.01, maxiter=20, **kwargs): @@ -411,53 +415,50 @@ def brent(tol=0.01, maxiter=20, **kwargs): raise Exception("The algorithm terminated without finding a valid value.") -def workload_graph(default_params): - R = [x / 100 for x in range(70, 100)] - cost_per_memorization = [sample(r=r, **default_params) for r in R] +def workload_graph(default_params, sampling_size=30): + R = np.linspace(0.7, 0.999, sampling_size).tolist() + default_params["max_cost_perday"] = math.inf + default_params["learn_limit_perday"] = int( + default_params["deck_size"] / default_params["learn_span"] + ) + default_params["review_limit_perday"] = math.inf + workload = [sample(r=r, workload_only=True, **default_params) for r in R] # this is for testing - # cost_per_memorization = [min(x, 2.3 * min(cost_per_memorization)) for x in cost_per_memorization] - min_w = min(cost_per_memorization) # minimum workload - max_w = max(cost_per_memorization) # maximum workload - min1_index = R.index(R[cost_per_memorization.index(min_w)]) + # workload = [min(x, 2.3 * min(workload)) for x in workload] + min_w = min(workload) # minimum workload + max_w = max(workload) # maximum workload + min1_index = R.index(R[workload.index(min_w)]) min_w2 = 0 min_w3 = 0 target2 = 2 * min_w target3 = 3 * min_w - for i in range(len(cost_per_memorization) - 1): - if (cost_per_memorization[i] <= target2) and ( - cost_per_memorization[i + 1] >= target2 - ): - if abs(cost_per_memorization[i] - target2) < abs( - cost_per_memorization[i + 1] - target2 - ): - min_w2 = cost_per_memorization[i] + for i in range(len(workload) - 1): + if (workload[i] <= target2) and (workload[i + 1] >= target2): + if abs(workload[i] - target2) < abs(workload[i + 1] - target2): + min_w2 = workload[i] else: - min_w2 = cost_per_memorization[i + 1] - - for i in range(len(cost_per_memorization) - 1): - if (cost_per_memorization[i] <= target3) and ( - cost_per_memorization[i + 1] >= target3 - ): - if abs(cost_per_memorization[i] - target3) < abs( - cost_per_memorization[i + 1] - target3 - ): - min_w3 = cost_per_memorization[i] + min_w2 = workload[i + 1] + + for i in range(len(workload) - 1): + if (workload[i] <= target3) and (workload[i + 1] >= target3): + if abs(workload[i] - target3) < abs(workload[i + 1] - target3): + min_w3 = workload[i] else: - min_w3 = cost_per_memorization[i + 1] + min_w3 = workload[i + 1] if min_w2 == 0: min2_index = len(R) else: - min2_index = R.index(R[cost_per_memorization.index(min_w2)]) + min2_index = R.index(R[workload.index(min_w2)]) min1_5_index = int(math.ceil((min2_index + 3 * min1_index) / 4)) if min_w3 == 0: min3_index = len(R) else: - min3_index = R.index(R[cost_per_memorization.index(min_w3)]) + min3_index = R.index(R[workload.index(min_w3)]) fig = plt.figure(figsize=(16, 8)) ax = fig.gca() @@ -465,14 +466,14 @@ def workload_graph(default_params): ax.fill_between( x=R[: min1_index + 1], y1=0, - y2=cost_per_memorization[: min1_index + 1], + y2=workload[: min1_index + 1], color="red", alpha=1, ) ax.fill_between( x=R[min1_index : min1_5_index + 1], y1=0, - y2=cost_per_memorization[min1_index : min1_5_index + 1], + y2=workload[min1_index : min1_5_index + 1], color="gold", alpha=1, ) @@ -481,7 +482,7 @@ def workload_graph(default_params): ax.fill_between( x=R[: min1_5_index + 1], y1=0, - y2=cost_per_memorization[: min1_5_index + 1], + y2=workload[: min1_5_index + 1], color="gold", alpha=1, ) @@ -489,21 +490,21 @@ def workload_graph(default_params): ax.fill_between( x=R[min1_5_index : min2_index + 1], y1=0, - y2=cost_per_memorization[min1_5_index : min2_index + 1], + y2=workload[min1_5_index : min2_index + 1], color="limegreen", alpha=1, ) ax.fill_between( x=R[min2_index : min3_index + 1], y1=0, - y2=cost_per_memorization[min2_index : min3_index + 1], + y2=workload[min2_index : min3_index + 1], color="gold", alpha=1, ) ax.fill_between( x=R[min3_index:], y1=0, - y2=cost_per_memorization[min3_index:], + y2=workload[min3_index:], color="red", alpha=1, ) @@ -521,7 +522,7 @@ def workload_graph(default_params): ax.set_ylim(0, lim) ax.set_ylabel("Workload (minutes of study per day)", fontsize=20) - ax.set_xlabel("Retention", fontsize=20) + ax.set_xlabel("Desired Retention", fontsize=20) ax.axhline(y=min_w, color="black", alpha=0.75, ls="--") ax.text( 0.701, @@ -565,7 +566,7 @@ def workload_graph(default_params): color="black", fontsize=12, ) - + fig.tight_layout(h_pad=0, w_pad=0) return fig @@ -635,4 +636,4 @@ def moving_average(data, window_size=365 // 20): ax.set_title("Memorized Count per Day") ax.grid(True) plt.show() - workload_graph(default_params).savefig("workload.png") + workload_graph(default_params, sampling_size=300).savefig("workload.png")