Skip to content

Commit

Permalink
save step
Browse files Browse the repository at this point in the history
  • Loading branch information
caillonantoine committed May 4, 2022
1 parent 9455a79 commit 16ea4a9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
11 changes: 8 additions & 3 deletions rave/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,8 @@ def __init__(self,

self.feature_match = feature_match

self.register_buffer("saved_step", torch.tensor(0))

def configure_optimizers(self):
gen_p = list(self.encoder.parameters())
gen_p += list(self.decoder.parameters())
Expand Down Expand Up @@ -552,6 +554,7 @@ def adversarial_combine(self, score_real, score_fake, mode="hinge"):

def training_step(self, batch, batch_idx):
p = Profiler()
self.saved_step += 1

gen_opt, dis_opt = self.optimizers()
x = batch.unsqueeze(1)
Expand Down Expand Up @@ -704,7 +707,7 @@ def validation_step(self, batch, batch_idx):
def validation_epoch_end(self, out):
audio, z = list(zip(*out))

if self.global_step > self.warmup:
if self.saved_step > self.warmup:
self.warmed_up = True

# LATENT SPACE ANALYSIS
Expand All @@ -728,8 +731,10 @@ def validation_epoch_end(self, out):

var_percent = [.8, .9, .95, .99]
for p in var_percent:
self.log(f"{p}%_manifold", np.argmax(var > p).astype(np.float32))
self.log(f"{p}%_manifold",
np.argmax(var > p).astype(np.float32))

y = torch.cat(audio, 0)[:64].reshape(-1)
self.logger.experiment.add_audio("audio_val", y, self.idx, self.sr)
self.logger.experiment.add_audio("audio_val", y,
self.saved_step.item(), self.sr)
self.idx += 1
3 changes: 1 addition & 2 deletions train_rave.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,8 @@ class args(Config):
name="rave"),
gpus=use_gpu,
callbacks=[validation_checkpoint, last_checkpoint],
resume_from_checkpoint=search_for_run(args.CKPT),
max_epochs=100000,
max_steps=args.MAX_STEPS,
**val_check,
)
trainer.fit(model, train, val)
trainer.fit(model, train, val, ckpt_path=search_for_run(args.CKPT))

0 comments on commit 16ea4a9

Please sign in to comment.