diff --git a/pipelines/trainer.py b/pipelines/trainer.py index 8860196..fda020a 100644 --- a/pipelines/trainer.py +++ b/pipelines/trainer.py @@ -16,7 +16,6 @@ class TrainingParams(TypedDict): epochs: int - batch_size: int loss_criterion: nn.Module accuracy_criterion: nn.Module optimizer: nn.Module diff --git a/tests/pipelines/test_experimenter.py b/tests/pipelines/test_experimenter.py index 6d297a2..c03b4b5 100644 --- a/tests/pipelines/test_experimenter.py +++ b/tests/pipelines/test_experimenter.py @@ -16,7 +16,6 @@ def test_run(): model = TestModel() training_params: TrainingParams = { "epochs": 1, - "batch_size": 1, "loss_criterion": nn.MSELoss(), "accuracy_criterion": nn.L1Loss(), "optimizer": Adam(model.parameters(), lr=0.0005),