Skip to content

Commit

Permalink
can specify dataset directly as an .npz file.
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Sep 1, 2024
1 parent 5bd18f1 commit 20dd6b3
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions impsy/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import click
from .utils import mdrnn_config
import os
from pathlib import Path


Expand Down Expand Up @@ -83,9 +82,14 @@ def train_mdrnn(
np.random.seed(SEED)

# Load dataset
dataset_location = f"{dataset_location}/"
dataset_filename = f"training-dataset-{str(dimension)}d.npz"
with np.load(dataset_location + dataset_filename, allow_pickle=True) as loaded:
dataset_location = Path(dataset_location)
dataset_default_name = f"training-dataset-{str(dimension)}d.npz"
if dataset_location.suffix == "":
dataset_default_name = f"training-dataset-{str(dimension)}d.npz"
dataset_location = dataset_location / dataset_default_name
assert dataset_location.suffix == ".npz", "dataset file to load must end with .npz"
click.secho(f"Dataset: {dataset_location}")
with np.load(dataset_location, allow_pickle=True) as loaded:
corpus = loaded["perfs"]
print("Loaded performances:", len(corpus))
print("Num touches:", np.sum([len(l) for l in corpus]))
Expand Down Expand Up @@ -177,7 +181,7 @@ def train_mdrnn(
"--source",
type=str,
default="datasets",
help="The source directory to obtain .npz dataset files.",
help="A .npz dataset file to use for training, or source directory to obtain .npz dataset files.",
)
@click.option(
"-M",
Expand Down

0 comments on commit 20dd6b3

Please sign in to comment.