From c3c1b2f802b7c773c252237ac553c4a9e12cd10c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=9A=D0=BE=D0=B7?= =?UTF-8?q?=D0=BB=D1=8E=D0=BA?= Date: Tue, 29 Mar 2022 01:22:17 +0500 Subject: [PATCH] refactor: radon + black + flake8 --- .../paraphrase_dataset/get_dataset.py | 6 ++-- .../para_phraser_plus/file_dataset.py | 2 +- .../para_phraser_plus/sql_dataset.py | 32 ++++++++----------- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/para_tri_dataset/paraphrase_dataset/get_dataset.py b/para_tri_dataset/paraphrase_dataset/get_dataset.py index 22578ce..6b7ce47 100644 --- a/para_tri_dataset/paraphrase_dataset/get_dataset.py +++ b/para_tri_dataset/paraphrase_dataset/get_dataset.py @@ -5,8 +5,10 @@ from para_tri_dataset.config import Config -DATASET_NAMES_MAPPING = {"paraphrase_plus_file": ParaPhraserPlusFileDataset, - "paraphrase_plus_sql": ParaPhraserPlusSQLDataset} +DATASET_NAMES_MAPPING = { + "paraphrase_plus_file": ParaPhraserPlusFileDataset, + "paraphrase_plus_sql": ParaPhraserPlusSQLDataset, +} def get_dataset_from_config(cfg: Config): diff --git a/para_tri_dataset/paraphrase_dataset/para_phraser_plus/file_dataset.py b/para_tri_dataset/paraphrase_dataset/para_phraser_plus/file_dataset.py index 28dbfec..e17d6e2 100644 --- a/para_tri_dataset/paraphrase_dataset/para_phraser_plus/file_dataset.py +++ b/para_tri_dataset/paraphrase_dataset/para_phraser_plus/file_dataset.py @@ -5,7 +5,7 @@ import json import os import zipfile -from typing import Tuple, TypedDict, List, Dict, Generator, Any, Sequence +from typing import Tuple, TypedDict, List, Dict, Generator, Sequence from para_tri_dataset.paraphrase_dataset.base import ParaphraseDataset from para_tri_dataset.config import Config diff --git a/para_tri_dataset/paraphrase_dataset/para_phraser_plus/sql_dataset.py b/para_tri_dataset/paraphrase_dataset/para_phraser_plus/sql_dataset.py index 43513ce..b152107 100644 --- a/para_tri_dataset/paraphrase_dataset/para_phraser_plus/sql_dataset.py +++ b/para_tri_dataset/paraphrase_dataset/para_phraser_plus/sql_dataset.py @@ -31,19 +31,16 @@ def size(self) -> int: return session.query(ParaphraserPlusDataset).count() def _scroll_rows(self, session, offset: int, fields) -> Query: - return (session - .query(*fields) - .order_by(ParaphraserPlusDataset.id) - .limit(self.scroll_size + 1) - .offset(offset)) + return session.query(*fields).order_by(ParaphraserPlusDataset.id).limit(self.scroll_size + 1).offset(offset) def get_phrase_by_id(self, phrase_id: int) -> ParaPhraserPlusPhrase: with self.storage_db.session_scope() as session: try: - row = (session - .query(ParaphraserPlusDataset.id, ParaphraserPlusDataset.text) - .where(ParaphraserPlusDataset.id == phrase_id) - ).one() + row = ( + session.query(ParaphraserPlusDataset.id, ParaphraserPlusDataset.text).where( + ParaphraserPlusDataset.id == phrase_id + ) + ).one() except sqlalchemy.exc.NoResultFound as ex: raise ValueError(f"not found phrase by id {phrase_id}") from ex @@ -51,15 +48,13 @@ def get_phrase_by_id(self, phrase_id: int) -> ParaPhraserPlusPhrase: @staticmethod def _get_paraphrases_query(phrase_id: int, fields) -> Query: - group_id_sub = (select(ParaphraserPlusDataset.group_id) - .where(ParaphraserPlusDataset.id == phrase_id) - ).scalar_subquery() + group_id_sub = ( + select(ParaphraserPlusDataset.group_id).where(ParaphraserPlusDataset.id == phrase_id) + ).scalar_subquery() - result_query = (select(*fields) - .where(and_(ParaphraserPlusDataset.group_id == group_id_sub, - ParaphraserPlusDataset.id != phrase_id) - ) - ) + result_query = select(*fields).where( + and_(ParaphraserPlusDataset.group_id == group_id_sub, ParaphraserPlusDataset.id != phrase_id) + ) return result_query @@ -104,11 +99,10 @@ def iterate_phrases(self, start_offset: int = 0) -> Generator[ParaPhraserPlusPhr fields = [ParaphraserPlusDataset.id, ParaphraserPlusDataset.text] rows = self._scroll_rows(session, offset, fields).all() - for row in rows[:self.scroll_size]: + for row in rows[: self.scroll_size]: yield ParaPhraserPlusPhrase(id=row.id, text=row.text) if self.scroll_size + 1 > len(rows): break offset += self.scroll_size -