diff --git a/tests/test_generate.py b/tests/test_generate.py index 45bfb82..333189e 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -3,7 +3,7 @@ from unittest import TestCase import torch -from transformers import AutoTokenizer, GPT2LMHeadModel, M2M100ForConditionalGeneration, GenerationConfig +from transformers import AutoTokenizer, GPT2LMHeadModel, M2M100ForConditionalGeneration, GenerationConfig, set_seed from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput from mbr import MBR, MBRConfig, MBROutput, MetricOutput @@ -13,6 +13,7 @@ class DecoderOnlyTestCase(TestCase): def setUp(self): + set_seed(42) self.model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval() self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2") @@ -189,6 +190,7 @@ def test_references_config(self): class EncoderDecoderTestCase(TestCase): def setUp(self): + set_seed(42) self.model = MBR(M2M100ForConditionalGeneration).from_pretrained("alirezamsh/small100").eval() self.tokenizer = AutoTokenizer.from_pretrained("alirezamsh/small100") self.tokenizer.tgt_lang = "fr" diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 16e63d4..f1bcc31 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -2,7 +2,7 @@ import unittest from unittest import TestCase -from transformers import AutoTokenizer, pipeline, GPT2LMHeadModel, M2M100ForConditionalGeneration +from transformers import AutoTokenizer, pipeline, GPT2LMHeadModel, M2M100ForConditionalGeneration, set_seed from mbr import MBRConfig from mbr import MBR @@ -11,6 +11,7 @@ class TextGenerationTestCase(TestCase): def setUp(self): + set_seed(42) self.model = MBR(GPT2LMHeadModel).from_pretrained("distilgpt2").eval() self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2") self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer) @@ -32,6 +33,7 @@ def test_pipeline(self): class TranslationTestCase(TestCase): def setUp(self): + set_seed(42) self.model = MBR(M2M100ForConditionalGeneration).from_pretrained("alirezamsh/small100").eval() self.tokenizer = AutoTokenizer.from_pretrained("alirezamsh/small100") self.pipeline = pipeline("translation_en_to_fr", model=self.model, tokenizer=self.tokenizer)