Skip to content

Commit

Permalink
ParaPhraserPlus датасет в виде sql базы данных
Browse files Browse the repository at this point in the history
  • Loading branch information
Андрей Козлюк committed Mar 28, 2022
1 parent d3b6fb9 commit 468d26a
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 4 deletions.
5 changes: 3 additions & 2 deletions para_tri_dataset/paraphrase_dataset/get_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""
Инициализация датасета из конфига
"""
from para_tri_dataset.paraphrase_dataset.para_phraser_plus import ParaPhraserPlusFileDataset
from para_tri_dataset.paraphrase_dataset.para_phraser_plus import ParaPhraserPlusFileDataset, ParaPhraserPlusSQLDataset

from para_tri_dataset.config import Config

DATASET_NAMES_MAPPING = {"paraphrase_plus_file": ParaPhraserPlusFileDataset}
DATASET_NAMES_MAPPING = {"paraphrase_plus_file": ParaPhraserPlusFileDataset,
"paraphrase_plus_sql": ParaPhraserPlusSQLDataset}


def get_dataset_from_config(cfg: Config):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from para_tri_dataset.paraphrase_dataset.para_phraser_plus.file_dataset import ParaPhraserPlusFileDataset
from para_tri_dataset.paraphrase_dataset.para_phraser_plus.sql_dataset import ParaPhraserPlusSQLDataset

__all__ = ["ParaPhraserPlusFileDataset"]
__all__ = ["ParaPhraserPlusFileDataset", "ParaPhraserPlusSQLDataset"]
108 changes: 107 additions & 1 deletion para_tri_dataset/paraphrase_dataset/para_phraser_plus/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,114 @@
"""
Датасет парафраз ParaPhraserPlus, который располагается в базе данных (например sqlite)
"""
from typing import Generator, Sequence

import sqlalchemy.exc
from sqlalchemy import select, and_
from sqlalchemy.orm import Query

from para_tri_dataset.config import Config
from para_tri_dataset.paraphrase_dataset.base import ParaphraseDataset

from para_tri_dataset.alchemy_utils import Database
from para_tri_dataset.paraphrase_dataset.para_phraser_plus.alchemy_models import ParaphraserPlusDataset
from para_tri_dataset.paraphrase_dataset.para_phraser_plus.base import ParaPhraserPlusPhrase


class ParaPhraserPlusSQLDataset(ParaphraseDataset):
pass
def __init__(self, storage_db: Database, scroll_size: int = 100):
self.storage_db = storage_db
self.scroll_size = scroll_size

@classmethod
def from_config(cls, cfg: Config) -> "ParaPhraserPlusSQLDataset":
scroll_size = cfg.get("scroll_size", 100)
database = Database.from_url(cfg.get("db_url"))
return cls(database, scroll_size)

def size(self) -> int:
with self.storage_db.session_scope() as session:
return session.query(ParaphraserPlusDataset).count()

def _scroll_rows(self, session, offset: int, fields) -> Query:
return (session
.query(*fields)
.order_by(ParaphraserPlusDataset.id)
.limit(self.scroll_size + 1)
.offset(offset))

def get_phrase_by_id(self, phrase_id: int) -> ParaPhraserPlusPhrase:
with self.storage_db.session_scope() as session:
try:
row = (session
.query(ParaphraserPlusDataset.id, ParaphraserPlusDataset.text)
.where(ParaphraserPlusDataset.id == phrase_id)
).one()
except sqlalchemy.exc.NoResultFound as ex:
raise ValueError(f"not found phrase by id {phrase_id}") from ex

return ParaPhraserPlusPhrase(id=row.id, text=row.text)

@staticmethod
def _get_paraphrases_query(phrase_id: int, fields) -> Query:
group_id_sub = (select(ParaphraserPlusDataset.group_id)
.where(ParaphraserPlusDataset.id == phrase_id)
).scalar_subquery()

result_query = (select(*fields)
.where(and_(ParaphraserPlusDataset.group_id == group_id_sub,
ParaphraserPlusDataset.id != phrase_id)
)
)

return result_query

def get_paraphrases(self, phrase: ParaPhraserPlusPhrase) -> Sequence[ParaPhraserPlusPhrase]:
with self.storage_db.session_scope() as session:
fields = [ParaphraserPlusDataset.id, ParaphraserPlusDataset.text]
try:
rows = session.execute(self._get_paraphrases_query(phrase.id, fields)).all()
except sqlalchemy.exc.NoResultFound as ex:
raise ValueError(f"not fount phrase by id {phrase.id}") from ex

return [ParaPhraserPlusPhrase(id=r.id, text=r.text) for r in rows]

def get_paraphrases_id(self, phrase_id: int) -> Sequence[int]:
with self.storage_db.session_scope() as session:
fields = [ParaphraserPlusDataset.id]
try:
rows = session.execute(self._get_paraphrases_query(phrase_id, fields)).all()
except sqlalchemy.exc.NoResultFound as ex:
raise ValueError(f"not found phrase by id {phrase_id}") from ex

return [r.id for r in rows]

def iterate_phrases_id(self, start_offset: int = 0) -> Generator[int, None, None]:
offset = start_offset
while True:
with self.storage_db.session_scope() as session:
fields = [ParaphraserPlusDataset.id]
rows = self._scroll_rows(session, offset, fields).all()

yield from (row.id for i, row in enumerate(rows, start=1) if i < self.scroll_size + 1)

if self.scroll_size + 1 > len(rows):
break

offset += self.scroll_size

def iterate_phrases(self, start_offset: int = 0) -> Generator[ParaPhraserPlusPhrase, None, None]:
offset = start_offset
while True:
with self.storage_db.session_scope() as session:
fields = [ParaphraserPlusDataset.id, ParaphraserPlusDataset.text]
rows = self._scroll_rows(session, offset, fields).all()

for row in rows[:self.scroll_size]:
yield ParaPhraserPlusPhrase(id=row.id, text=row.text)

if self.scroll_size + 1 > len(rows):
break

offset += self.scroll_size

0 comments on commit 468d26a

Please sign in to comment.