diff --git a/forge/test/mlir/mnist/training/test_training.py b/forge/test/mlir/mnist/training/test_training.py index c73d70c3d..e820a63f5 100644 --- a/forge/test/mlir/mnist/training/test_training.py +++ b/forge/test/mlir/mnist/training/test_training.py @@ -29,7 +29,7 @@ def test_mnist_training(): # Set training hyperparameters num_epochs = 3 - batch_size = 2048 + batch_size = 1024 learning_rate = 0.001 # Load dataset