
Commit 42baaaf
ignore same-day reviews
L-M-Sherlock committed Nov 27, 2024
1 parent 9539e13 commit 42baaaf
Showing 1 changed file with 8 additions and 8 deletions.
src/fsrs_optimizer/fsrs_optimizer.py (8 additions, 8 deletions)

@@ -233,6 +233,7 @@ def __init__(
         )
         self.t_train = torch.tensor(dataframe["delta_t"].values, dtype=torch.float)
         self.y_train = torch.tensor(dataframe["y"].values, dtype=torch.float)
+        self.last_rating = torch.tensor(dataframe["last_rating"].values, dtype=torch.long)
         self.seq_len = torch.tensor(
             dataframe["tensor"].map(len).values, dtype=torch.long
         )
@@ -252,6 +253,7 @@ def __init__(
                     sequences_truncated.transpose(0, 1).to(device),
                     self.t_train[start_index:end_index].to(device),
                     self.y_train[start_index:end_index].to(device),
+                    self.last_rating[start_index:end_index].to(device),
                     seq_lens.to(device),
                 )

@@ -312,6 +314,7 @@ def __init__(
         self.avg_eval_losses = []
         self.loss_fn = nn.BCELoss(reduction="none")
         self.float_delta_t = float_delta_t
+        self.pls_penalty = 4

     def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]):
         self.train_set = BatchDataset(
@@ -343,14 +346,13 @@ def train(self, verbose: bool = True):
             for i, batch in enumerate(self.train_data_loader):
                 self.model.train()
                 self.optimizer.zero_grad()
-                sequences, delta_ts, labels, seq_lens = batch
+                sequences, delta_ts, labels, last_ratings, seq_lens = batch
                 real_batch_size = seq_lens.shape[0]
                 outputs, _ = self.model(sequences)
                 stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
                 retentions = power_forgetting_curve(delta_ts, stabilities)
-                pls_flag = sequences[seq_lens-1, torch.arange(real_batch_size), 1] == 1
                 penalty = torch.ones_like(retentions, requires_grad=False)
-                penalty[pls_flag] *= 2
+                penalty[last_ratings == 1] *= self.pls_penalty
                 loss = (self.loss_fn(retentions, labels) * penalty).sum()
                 loss.backward()
                 if self.float_delta_t:
@@ -387,19 +389,19 @@ def eval(self):
                 if len(dataset) == 0:
                     losses.append(0)
                     continue
-                sequences, delta_ts, labels, seq_lens = (
+                sequences, delta_ts, labels, last_ratings, seq_lens = (
                     dataset.x_train,
                     dataset.t_train,
                     dataset.y_train,
+                    dataset.last_rating,
                     dataset.seq_len,
                 )
                 real_batch_size = seq_lens.shape[0]
                 outputs, _ = self.model(sequences.transpose(0, 1))
                 stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 penalty = torch.ones_like(retentions, requires_grad=False)
-                pls_flag = sequences[torch.arange(real_batch_size), seq_lens-1, 1] == 1
-                penalty[pls_flag] *= 2
+                penalty[last_ratings == 1] *= self.pls_penalty
                 loss = (self.loss_fn(retentions, labels) * penalty).mean()
                 losses.append(loss)
             self.avg_train_losses.append(losses[0])
@@ -883,7 +885,6 @@ def cum_concat(x):
             "real_days",
             "review_rating",
             "t_history",
-            "last_rating",
             "y",
         ],
         inplace=True,
@@ -1178,7 +1179,6 @@ def train(
             lambda x: lineToTensor(list(zip([x["t_history"]], [x["r_history"]]))[0]),
             axis=1,
         )
-        self.dataset["group"] = self.dataset["r_history"] + self.dataset["t_history"]
         if verbose:
             tqdm.write("Tensorized!")

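For readers who want the new weighting in isolation, below is a minimal standalone sketch of what the changed train() lines compute. It is not code from the repository: the tensor names mirror the diff, the sample values are made up, and power_forgetting_curve is replaced by a precomputed retentions tensor.

import torch
import torch.nn as nn

# Stand-ins for one batch; in the real code, retentions comes from
# power_forgetting_curve(delta_ts, stabilities).
retentions = torch.tensor([0.9, 0.7, 0.8], requires_grad=True)
labels = torch.tensor([1.0, 0.0, 1.0])
last_ratings = torch.tensor([3, 1, 2])  # 1 == "Again": the previous review was a lapse

pls_penalty = 4  # mirrors self.pls_penalty in the diff
loss_fn = nn.BCELoss(reduction="none")

# Post-lapse reviews are weighted 4x in the per-element BCE loss;
# all other reviews keep weight 1.
penalty = torch.ones_like(retentions, requires_grad=False)
penalty[last_ratings == 1] *= pls_penalty
loss = (loss_fn(retentions, labels) * penalty).sum()
loss.backward()
print(penalty)  # tensor([1., 4., 1.])

Two things in the rest of the diff make this work: the preprocessing code stops dropping the last_rating column, so it reaches BatchDataset and is tensorized alongside delta_t and y, and every batch now carries five tensors instead of four, which is why both train() and eval() unpack last_ratings. Compared with the old code, the post-lapse flag no longer has to be read back out of the input sequences, and the weight is a named attribute (self.pls_penalty) rather than a hard-coded 2.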