Skip to content

Commit

Permalink
cleaned up logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fuest committed Oct 3, 2024
1 parent f747c70 commit 1630e7e
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 185 deletions.
5 changes: 3 additions & 2 deletions config/model_config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
device: 1 # 0, cpu
seq_len: 96 # should not be changed for the current datasets
input_dim: 2 # or 1 depending on user, but is dynamically set
noise_dim: 256
Expand Down Expand Up @@ -66,7 +67,7 @@ diffusion_ts:

acgan:
batch_size: 32
n_epochs: 200
n_epochs: 10
lr_gen: 3e-4
lr_discr: 1e-4
warm_up_epochs: 50
warm_up_epochs: 5
1 change: 0 additions & 1 deletion datasets/openpower.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class OpenPowerDataset(Dataset):
Expand Down
1 change: 0 additions & 1 deletion datasets/pecanstreet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PecanStreetDataManager:
Expand Down
42 changes: 42 additions & 0 deletions datasets/timeseries_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from torch.utils.data import Dataset


class TimeSeriesDataset(Dataset):
def __init__(self, dataframe, time_series_column_name, conditioning_vars=None):
"""
Initializes the TimeSeriesDataset.
Args:
dataframe (pd.DataFrame): The input DataFrame containing the time series data and optional conditioning variables.
time_series_column_name (str): The name of the column containing the time series data.
conditioning_vars (list of str, optional): List of column names to be used as conditioning variables.
"""
self.data = dataframe.reset_index(drop=True)
self.conditioning_vars = conditioning_vars or []
self.time_series_column_name = time_series_column_name

if self.time_series_column_name not in self.data.columns:
raise ValueError(
f"Time series column '{self.time_series_column_name}' not found in DataFrame."
)

for var in self.conditioning_vars:
if var not in self.data.columns:
raise ValueError(
f"Conditioning variable '{var}' not found in DataFrame."
)

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
time_series = self.data.iloc[idx][self.time_series_column]
time_series = torch.tensor(time_series, dtype=torch.float32)

conditioning_vars_dict = {}
for var in self.conditioning_vars:
value = self.data.iloc[idx][var]
conditioning_vars_dict[var] = torch.tensor(value, dtype=torch.long)

return time_series, conditioning_vars_dict
Loading

0 comments on commit 1630e7e

Please sign in to comment.