From 63ce10cdd9ca723247276bc276705ff542fcb4da Mon Sep 17 00:00:00 2001 From: Ben van Werkhoven Date: Fri, 28 Jun 2024 11:24:06 +0200 Subject: [PATCH] fix minor codestyle issues --- kernel_tuner/cache/cache.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/kernel_tuner/cache/cache.py b/kernel_tuner/cache/cache.py index 6045b9b1..c7556f57 100644 --- a/kernel_tuner/cache/cache.py +++ b/kernel_tuner/cache/cache.py @@ -7,22 +7,20 @@ import json from collections import OrderedDict -from collections.abc import Mapping from datetime import datetime from functools import cached_property from functools import lru_cache as cache from os import PathLike from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import jsonschema -import kernel_tuner.util as util import numpy as np +from kernel_tuner import util from semver import Version from .convert import convert_cache from .file import append_cache_line -from .json import CacheFileJSON, CacheLineJSON from .paths import get_schema_path from .versions import LATEST_VERSION, VERSIONS @@ -259,7 +257,7 @@ def get_from_params(self, default=None, **params) -> Union[Cache.Line, list[Cach results = [self[k] for k in line_ids] if not results: return default - elif len(results) == 1: + if len(results) == 1: results = results[0] return results @@ -297,7 +295,7 @@ def __get_line_id_from_tune_params_dict(self, tune_params: dict) -> str: class ReadableLines(Lines): """Cache lines in a read_only cache file.""" - def append(*args, **kwargs): + def append(self, *args, **kwargs): """Method to append lines to cache file, should not happen with read-only cache""" raise ValueError("Attempting to write to read-only cache") @@ -312,7 +310,7 @@ def __init__(self, *args, **kwargs): def __getattr__(self, name): if not name.startswith("_"): return self[name] - return super(Line, self).__getattr__(name) + return super(dict, self).__getattr__(name) def _encode_cache_line(line_id, line): @@ -370,9 +368,9 @@ def default(self, o): return float(o) elif isinstance(o, np.ndarray): return o.tolist() - super().default(o) + return super().default(o) - def iterencode(self, obj, *args, **kwargs): + def iterencode(self, obj, **kwargs): """encode an iterator, ensuring 'cache' is the last entry for encoded dicts""" # ensure key 'cache' is last in any encoded dictionary @@ -381,7 +379,7 @@ def iterencode(self, obj, *args, **kwargs): if "cache" in obj: obj.move_to_end("cache") - yield from super().iterencode(obj, *args, **kwargs) + yield from super().iterencode(obj, **kwargs) def read_cache_file(filename: PathLike): @@ -397,7 +395,7 @@ def read_cache_file(filename: PathLike): try: data = json.load(file) except json.JSONDecodeError as e: - raise InvalidCacheError(filename, "Cache file is not parsable", e) + raise InvalidCacheError(filename, "Cache file is not parsable", e) from e return data