Skip to content

Commit

Permalink
add train
Browse files Browse the repository at this point in the history
opimize.py:
	add get_loss, train, evaluate
adjust configs
  • Loading branch information
armingh2000 committed Jan 30, 2024
1 parent 5d1f864 commit 46ef7f3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 30 deletions.
8 changes: 4 additions & 4 deletions src/configs/model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
learning_rate = 0.001
cyclic_loss = "MSELoss"
acyclic_loss = "L1Loss"
# Relative weights of the two loss terms (should sum to 1.0);
# acyclic features are weighted more heavily than cyclic ones.
cyclic_loss_weight = 0.3
acyclic_loss_weight = 0.7
optimizer = "Adam"

# Dataset
train_dataset_path = project_root / "data/historical/train/train_dataset.pt"
val_dataset_path = project_root / "data/historical/train/val_dataset.pt"
test_dataset_path = project_root / "data/historical/train/test_dataset.pt"
# NOTE(review): 0.1% train / 0.01% val look like leftover smoke-test
# values (previously 0.7 / 0.15) -- confirm before a real training run.
train_split = 0.001
val_split = 0.0001
torch_seed = 57885161  # prime number
8 changes: 3 additions & 5 deletions src/model/train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,11 @@
# Setting torch seed so dataset splits and weight init are reproducible
torch.manual_seed(configs.torch_seed)

# Build the train/val/test DataLoaders from the historical stock data.
train_loader, val_loader, test_loader = prepare_loaders(metadata, spark, logger)

model = StockLSTM()

train(model, train_loader, val_loader, logger)

# revert std streams
revert_streams()
64 changes: 43 additions & 21 deletions src/model/train/optimize.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,68 @@
import torch
import torch.nn as nn
import src.configs as configs
from tqdm import tqdm

# Resolve the loss criteria and their weights once at import time from the
# training config (e.g. "MSELoss" for cyclic features, "L1Loss" for acyclic).
cyclic_loss_function = getattr(nn, configs.cyclic_loss)()
acyclic_loss_function = getattr(nn, configs.acyclic_loss)()
cyclic_loss_weight = configs.cyclic_loss_weight
acyclic_loss_weight = configs.acyclic_loss_weight


def get_loss(pred, target, n_cyclic=5):
    """Return the weighted sum of the cyclic and acyclic loss terms.

    The first ``n_cyclic`` output columns are treated as cyclic features
    (scored with the configured cyclic criterion) and the remaining columns
    as acyclic features (configured acyclic criterion).

    Args:
        pred: model output; assumed shape (batch, n_features) -- TODO confirm.
        target: ground truth, same shape as ``pred``.
        n_cyclic: number of leading columns that are cyclic features
            (default 5, the previous hard-coded split point).

    Returns:
        Scalar tensor: ``cyclic_loss * w_c + acyclic_loss * w_a``.
    """
    cyclic_loss = cyclic_loss_function(pred[:, :n_cyclic], target[:, :n_cyclic])
    acyclic_loss = acyclic_loss_function(pred[:, n_cyclic:], target[:, n_cyclic:])

    return cyclic_loss * cyclic_loss_weight + acyclic_loss * acyclic_loss_weight


def train(model, train_loader, val_loader, logger):
    """Train ``model`` on ``train_loader`` with periodic validation.

    Runs ``configs.epochs`` epochs using the configured optimizer and the
    combined cyclic/acyclic loss from :func:`get_loss`.  Validation loss is
    logged via :func:`evaluate` every 5 batches and again at the end of
    each epoch.

    Args:
        model: the network to train (modified in place).
        train_loader: DataLoader yielding (sequences, targets) batches.
        val_loader: DataLoader used for periodic validation.
        logger: logger used for progress and validation messages.
    """
    optimizer = getattr(torch.optim, configs.optimizer)(
        model.parameters(), lr=configs.learning_rate
    )

    for epoch in range(configs.epochs):
        logger.info(f"Starting epoch {epoch + 1}/{configs.epochs}")
        model.train()
        for iteration, data in tqdm(
            enumerate(train_loader, 0), unit="batch", total=len(train_loader)
        ):
            sequences, targets = data
            optimizer.zero_grad()
            y_pred = model(sequences)
            loss = get_loss(y_pred, targets)
            loss.backward()
            optimizer.step()

            if (iteration + 1) % 5 == 0:
                evaluate(model, val_loader, logger)
                # evaluate() switches the model to eval mode; restore train
                # mode so dropout/batch-norm behave correctly for the
                # remaining batches of this epoch.
                model.train()

        evaluate(model, val_loader, logger)
        logger.info(f"Epoch {epoch + 1}/{configs.epochs} complete.")

    logger.info("Training complete.")

def evaluate(model, val_loader, logger):
    """Compute and log the mean validation loss over ``val_loader``.

    Switches the model to eval mode and disables gradient tracking; the
    caller is responsible for restoring train mode afterwards if it is
    mid-training.

    Args:
        model: the network to evaluate.
        val_loader: DataLoader yielding (sequences, targets) batches.
        logger: logger used to report the mean loss.
    """
    logger.info("Evaluating model ...")
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for sequences, targets in val_loader:
            y_pred = model(sequences)
            val_loss += get_loss(y_pred, targets).item()

    # Guard against an empty loader (e.g. a too-small validation split)
    # which would otherwise raise ZeroDivisionError.
    val_loss /= max(len(val_loader), 1)
    logger.info(f"Validation Loss: {val_loss:.4f}")


def test_model(model, test_loader):
def test_model(model, test_loader, logger):
model.eval()
test_loss = 0.0
loss_function = nn.MSELoss()
Expand All @@ -51,4 +73,4 @@ def test_model(model, test_loader):
test_loss += loss_function(y_pred, targets).item()

test_loss /= len(test_loader)
print(f"Test Loss: {test_loss:.4f}")
logger.info(f"Test Loss: {test_loss:.4f}")

0 comments on commit 46ef7f3

Please sign in to comment.