From ce688db0607bd1be88399990524487ffba1eebe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=9A=D0=BE=D0=B7?= =?UTF-8?q?=D0=BB=D1=8E=D0=BA?= Date: Wed, 30 Mar 2022 00:32:46 +0500 Subject: [PATCH] =?UTF-8?q?=D1=82=D0=B5=D1=81=D1=82=20=D0=BD=D0=B0=20?= =?UTF-8?q?=D0=B7=D0=B0=D0=B3=D1=80=D1=83=D0=B7=D0=BA=D1=83=20sql=20=D0=B4?= =?UTF-8?q?=D0=B0=D1=82=D0=B0=D1=81=D0=B5=D1=82=D0=B0=20=D0=B8=D0=B7=20?= =?UTF-8?q?=D0=BA=D0=BE=D0=BD=D1=84=D0=B8=D0=B3=D0=B0=20+=20=D1=80=D0=B0?= =?UTF-8?q?=D0=B7=D0=B4=D0=B5=D0=BB=D0=B5=D0=BD=D0=B8=D0=B5=20=D1=82=D0=B5?= =?UTF-8?q?=D1=81=D1=82=D0=BE=D0=B2=D1=8B=D1=85=20=D1=84=D0=B8=D0=BA=D1=81?= =?UTF-8?q?=D1=82=D1=83=D1=80=20=D0=BD=D0=B0=20in=20memory=20=D0=B1=D0=B0?= =?UTF-8?q?=D0=B7=D1=83=20=D0=B4=D0=B0=D1=82=D0=B0=D1=81=D0=B5=D1=82=D0=B0?= =?UTF-8?q?=20=D0=B8=20=D0=B1=D0=B0=D0=B7=D1=83=20=D0=B4=D0=B0=D1=82=D0=B0?= =?UTF-8?q?=D1=81=D0=B5=D1=82=D0=B0=20=D0=B2=20=D1=84=D0=B0=D0=B9=D0=BB?= =?UTF-8?q?=D0=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/test_sql_dataset.py | 78 +++++++++++++------ 1 file changed, 55 insertions(+), 23 deletions(-) diff --git a/para_tri_dataset/paraphrase_dataset/para_phraser_plus/test/test_sql_dataset.py b/para_tri_dataset/paraphrase_dataset/para_phraser_plus/test/test_sql_dataset.py index af1884c..79d61d1 100644 --- a/para_tri_dataset/paraphrase_dataset/para_phraser_plus/test/test_sql_dataset.py +++ b/para_tri_dataset/paraphrase_dataset/para_phraser_plus/test/test_sql_dataset.py @@ -4,7 +4,9 @@ from sqlalchemy.orm import sessionmaker from para_tri_dataset.alchemy_utils import Database +from para_tri_dataset.config import Config +from para_tri_dataset.paraphrase_dataset import get_dataset_from_config from para_tri_dataset.paraphrase_dataset.para_phraser_plus import ParaPhraserPlusSQLDataset from para_tri_dataset.paraphrase_dataset.para_phraser_plus.base import ParaPhraserPlusPhrase @@ -29,67 +31,97 @@ def paraphrases(phrases): ] -@pytest.fixture -def database(paraphrases) -> Database: - engine = create_engine("sqlite:///") - Session = sessionmaker(engine, expire_on_commit=False) - db = Database(engine, Session) +def create_database(db: Database, dataset_paraphrases): db.create_all(Base) with db.session_scope() as session: - for group_id, phrases in enumerate(paraphrases): + for group_id, phrases in enumerate(dataset_paraphrases): session.add_all([ParaphraserPlusDataset(id=p.id, text=p.text, group_id=group_id) for p in phrases]) session.commit() + +@pytest.fixture +def dataset_config(tmp_path, paraphrases): + dataset_filepath = tmp_path / 'test-paraphraser-plus.sqlite' + db_url = f"sqlite:///{dataset_filepath.absolute()}" + + engine = create_engine(db_url) + Session = sessionmaker(engine, expire_on_commit=False) + + db = Database(engine, Session) + create_database(db, paraphrases) + + cfg = Config(name="paraphrase_plus_sql", data={'db_url': db_url}, nested_configs=[]) + yield cfg + + dataset_filepath.unlink() + + +@pytest.fixture +def in_memory_database(paraphrases) -> Database: + engine = create_engine("sqlite:///") + Session = sessionmaker(engine, expire_on_commit=False) + db = Database(engine, Session) + + create_database(db, paraphrases) return db -def test_get_phrase_by_id(database, phrases): - dataset = ParaPhraserPlusSQLDataset(database) +@pytest.fixture +def in_memory_dataset(in_memory_database) -> ParaPhraserPlusSQLDataset: + return ParaPhraserPlusSQLDataset(in_memory_database) + + +def test_load_from_config(dataset_config): + _ = ParaPhraserPlusSQLDataset.from_config(dataset_config) + _ = get_dataset_from_config(dataset_config) + + with pytest.raises(ValueError): + dataset_config.name = "foo" + _ = get_dataset_from_config(dataset_config) + + +def test_get_phrase_by_id(in_memory_dataset, phrases): for phrase in phrases: - dataset_phrase = dataset.get_phrase_by_id(phrase.id) + dataset_phrase = in_memory_dataset.get_phrase_by_id(phrase.id) assert phrase == dataset_phrase -def test_iterate_phrases(database, phrases): - dataset = ParaPhraserPlusSQLDataset(database) +def test_iterate_phrases(in_memory_dataset, phrases): - for orig_phrase, dataset_phrase in zip(phrases, dataset.iterate_phrases()): + for orig_phrase, dataset_phrase in zip(phrases, in_memory_dataset.iterate_phrases()): assert orig_phrase == dataset_phrase - for orig_phrase, dataset_phrase in zip(phrases[1:], dataset.iterate_phrases(offset=1)): + for orig_phrase, dataset_phrase in zip(phrases[1:], in_memory_dataset.iterate_phrases(offset=1)): assert orig_phrase == dataset_phrase -def test_iterate_phrases_ids(database, phrases): - dataset = ParaPhraserPlusSQLDataset(database) +def test_iterate_phrases_ids(in_memory_dataset, phrases): - for orig_phrase, phrase_id in zip(phrases, dataset.iterate_phrases_id()): + for orig_phrase, phrase_id in zip(phrases, in_memory_dataset.iterate_phrases_id()): assert orig_phrase.id == phrase_id - for orig_phrase, phrase_id in zip(phrases[1:], dataset.iterate_phrases_id(offset=1)): + for orig_phrase, phrase_id in zip(phrases[1:], in_memory_dataset.iterate_phrases_id(offset=1)): assert orig_phrase.id == phrase_id -def test_get_paraphrases(database, paraphrases): - dataset = ParaPhraserPlusSQLDataset(database) +def test_get_paraphrases(in_memory_dataset, paraphrases): for phrases in paraphrases: for i in range(len(phrases)): phrase, phrase_paraphrases = phrases[i], [p for j, p in enumerate(phrases) if j != i] - dataset_paraphrases = dataset.get_paraphrases(phrase) + dataset_paraphrases = in_memory_dataset.get_paraphrases(phrase) assert phrase_paraphrases == dataset_paraphrases -def test_get_paraphrases_id(database, paraphrases): - dataset = ParaPhraserPlusSQLDataset(database) +def test_get_paraphrases_id(in_memory_dataset, paraphrases): for phrases in paraphrases: for i in range(len(phrases)): phrase_id, phrase_paraphrases_id = phrases[i].id, [p.id for j, p in enumerate(phrases) if j != i] - dataset_paraphrases_id = dataset.get_paraphrases_id(phrase_id) + dataset_paraphrases_id = in_memory_dataset.get_paraphrases_id(phrase_id) assert phrase_paraphrases_id == dataset_paraphrases_id