Skip to content

Commit

Permalink
refactor: radon + black + flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
Андрей Козлюк committed Mar 28, 2022
1 parent b48ad8a commit c3c1b2f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 22 deletions.
6 changes: 4 additions & 2 deletions para_tri_dataset/paraphrase_dataset/get_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

from para_tri_dataset.config import Config

# Registry of paraphrase dataset classes keyed by dataset name.
# NOTE(review): presumably consulted by get_dataset_from_config to pick the
# class named in the config -- confirm against that function's body.
DATASET_NAMES_MAPPING = {
    "paraphrase_plus_file": ParaPhraserPlusFileDataset,
    "paraphrase_plus_sql": ParaPhraserPlusSQLDataset,
}


def get_dataset_from_config(cfg: Config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import os
import zipfile
from typing import Tuple, TypedDict, List, Dict, Generator, Any, Sequence
from typing import Tuple, TypedDict, List, Dict, Generator, Sequence

from para_tri_dataset.paraphrase_dataset.base import ParaphraseDataset
from para_tri_dataset.config import Config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,30 @@ def size(self) -> int:
return session.query(ParaphraserPlusDataset).count()

def _scroll_rows(self, session, offset: int, fields) -> Query:
    """Build the query that fetches one scroll page of rows.

    Args:
        session: Session object used to create the query.
        offset: Number of rows (in ``ParaphraserPlusDataset.id`` order) to skip.
        fields: Iterable of columns/entities to select; unpacked into
            ``session.query``.

    Returns:
        A query ordered by ``ParaphraserPlusDataset.id`` that is limited to
        ``self.scroll_size + 1`` rows starting at ``offset``. The query is
        not executed here.
    """
    # One row more than the page size is requested so the caller can tell
    # from the result length whether another page exists.
    return (
        session.query(*fields)
        .order_by(ParaphraserPlusDataset.id)
        .limit(self.scroll_size + 1)
        .offset(offset)
    )

def get_phrase_by_id(self, phrase_id: int) -> ParaPhraserPlusPhrase:
    """Load a single phrase by its id.

    Args:
        phrase_id: Value matched against ``ParaphraserPlusDataset.id``.

    Returns:
        A ``ParaPhraserPlusPhrase`` built from the row's ``id`` and ``text``.

    Raises:
        ValueError: If no row with the given id exists (chained from
            ``sqlalchemy.exc.NoResultFound``).
    """
    with self.storage_db.session_scope() as session:
        query = session.query(ParaphraserPlusDataset.id, ParaphraserPlusDataset.text).where(
            ParaphraserPlusDataset.id == phrase_id
        )
        try:
            # The lookup runs exactly once; .one() raises when no row matches.
            row = query.one()
        except sqlalchemy.exc.NoResultFound as ex:
            raise ValueError(f"not found phrase by id {phrase_id}") from ex

        return ParaPhraserPlusPhrase(id=row.id, text=row.text)

@staticmethod
def _get_paraphrases_query(phrase_id: int, fields) -> Query:
    """Build a statement selecting the other phrases in a phrase's group.

    Args:
        phrase_id: Id of the phrase whose group is looked up.
        fields: Iterable of columns/entities to select; unpacked into
            ``select``.

    Returns:
        A statement selecting ``fields`` for every row whose ``group_id``
        equals the group of ``phrase_id``, excluding the row with
        ``phrase_id`` itself. The statement is not executed here.
    """
    # NOTE(review): the statement is built with select(), so the returned
    # object may not be an ORM Query despite the annotation -- confirm.

    # Scalar subquery yielding the group_id of the requested phrase.
    group_id_sub = (
        select(ParaphraserPlusDataset.group_id).where(ParaphraserPlusDataset.id == phrase_id)
    ).scalar_subquery()

    # Same group, but not the phrase itself.
    result_query = select(*fields).where(
        and_(ParaphraserPlusDataset.group_id == group_id_sub, ParaphraserPlusDataset.id != phrase_id)
    )

    return result_query

Expand Down Expand Up @@ -104,11 +99,10 @@ def iterate_phrases(self, start_offset: int = 0) -> Generator[ParaPhraserPlusPhr
fields = [ParaphraserPlusDataset.id, ParaphraserPlusDataset.text]
rows = self._scroll_rows(session, offset, fields).all()

for row in rows[:self.scroll_size]:
for row in rows[: self.scroll_size]:
yield ParaPhraserPlusPhrase(id=row.id, text=row.text)

if self.scroll_size + 1 > len(rows):
break

offset += self.scroll_size

0 comments on commit c3c1b2f

Please sign in to comment.