diff --git a/torchbenchmark/models/llama/model.py b/torchbenchmark/models/llama/model.py
index a01f4cae6a..760179e020 100644
--- a/torchbenchmark/models/llama/model.py
+++ b/torchbenchmark/models/llama/model.py
@@ -19,7 +19,7 @@ class ModelArgs:
     multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
     norm_eps: float = 1e-5
-    max_batch_size: int = 32 # From the paper they use a batch size of 4M for training
+    max_batch_size: int = 64 # From the paper they use a batch size of 4M for training
     max_seq_len: int = 1024
     device: Optional[str] = None
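
For context (not part of the diff): in the reference LLaMA implementation, `max_batch_size` sizes the pre-allocated per-layer KV caches, so doubling it from 32 to 64 roughly doubles that cache memory. A minimal sketch of the effect, with `n_heads` and `head_dim` chosen purely for illustration:

```python
# Sketch only: estimates one layer's K-cache footprint under the new setting.
import torch

max_batch_size = 64   # value set by this change
max_seq_len = 1024    # from ModelArgs above
n_heads = 32          # assumed head count, for illustration only
head_dim = 128        # assumed head dim, for illustration only

# One layer's K cache; the V cache is shaped identically.
cache_k = torch.zeros(max_batch_size, max_seq_len, n_heads, head_dim)
print(cache_k.numel() * cache_k.element_size() / 2**20, "MiB per cache")
```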