Skip to content

Commit

Permalink
update config names in tests
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
  • Loading branch information
akoumpa committed Sep 27, 2024
1 parent b57fd6c commit b7c0bfa
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/collections/llm/gpt/model/test_mistral.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn.functional as F

from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralNeMo2407Config12B, MistralNeMo2407Config123B
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralNeMoConfig12B, MistralNeMoConfig123B


def test_mistral_config7b():
Expand All @@ -25,7 +25,7 @@ def test_mistral_config7b():


def test_mistral_nemo_config_12b():
config = MistralNeMo2407Config12B()
config = MistralNeMoConfig12B()
assert config.normalization == "RMSNorm"
assert config.activation_func == F.silu
assert config.position_embedding_type == "rope"
Expand All @@ -49,7 +49,7 @@ def test_mistral_nemo_config_12b():


def test_mistral_nemo_config_123b():
config = MistralNeMo2407Config123B()
config = MistralNeMoConfig123B()
assert config.normalization == "RMSNorm"
assert config.activation_func == F.silu
assert config.position_embedding_type == "rope"
Expand Down

0 comments on commit b7c0bfa

Please sign in to comment.