Skip to content

Commit

Permalink
black/lint
Browse files Browse the repository at this point in the history
  • Loading branch information
zhu0619 committed Aug 12, 2024
1 parent b796eb7 commit 08cea4e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
24 changes: 16 additions & 8 deletions molfeat/plugins/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ def CalculatorFactory(
entry_point_name: str,
load: Literal[True] = True,
entry_point_group: Optional[str] = None,
) -> Union[Type["SerializableCalculator"], Callable]: ...
) -> Union[Type["SerializableCalculator"], Callable]:
...


@overload
def CalculatorFactory(
entry_point_name: str, load: Literal[False], entry_point_group: Optional[str] = None
) -> EntryPoint: ...
) -> EntryPoint:
...


def CalculatorFactory(
Expand Down Expand Up @@ -132,13 +134,15 @@ def TransformerFactory(
entry_point_name: str,
load: Literal[True] = True,
entry_point_group: Optional[str] = None,
) -> Union[Type["MoleculeTransformer"], Callable]: ...
) -> Union[Type["MoleculeTransformer"], Callable]:
...


@overload
def TransformerFactory(
entry_point_name: str, load: Literal[False], entry_point_group: Optional[str] = None
) -> EntryPoint: ...
) -> EntryPoint:
...


def TransformerFactory(
Expand Down Expand Up @@ -184,11 +188,13 @@ def PretrainedTransformerFactory(
entry_point_name: str,
load: Literal[True] = True,
entry_point_group: Optional[str] = None,
) -> Union[Type["PretrainedMolTransformer"], Callable]: ...
) -> Union[Type["PretrainedMolTransformer"], Callable]:
...


@overload
def PretrainedTransformerFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...
def PretrainedTransformerFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint:
...


def PretrainedTransformerFactory(
Expand Down Expand Up @@ -233,13 +239,15 @@ def DefaultFactory(
entry_point_name: str,
load: Literal[True] = True,
entry_point_group: str = None,
) -> Union[Type["PretrainedMolTransformer"], Callable]: ...
) -> Union[Type["PretrainedMolTransformer"], Callable]:
...


@overload
def DefaultFactory(
entry_point_name: str, load: Literal[False], entry_point_group: str = None
) -> EntryPoint: ...
) -> EntryPoint:
...


def DefaultFactory(
Expand Down
2 changes: 1 addition & 1 deletion molfeat/trans/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __str__(self):
return self.__repr__()

def __eq__(self, other):
same_type = type(self) == type(other)
same_type = isinstance(self, type(other))
return same_type and all(
[getattr(other, k) == v for k, v in self.get_params() if not callable(v)]
)
Expand Down
9 changes: 6 additions & 3 deletions molfeat/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,12 @@ def __call__(
self._sync_cache()
return self.fetch(mols)

def clear(self, *args, **kwargs): ...
def clear(self, *args, **kwargs):
...

@abc.abstractmethod
def update(self, new_cache: Mapping[Any, Any]): ...
def update(self, new_cache: Mapping[Any, Any]):
...

def get(self, key, default: Optional[Any] = None):
"""Get the cached value for a specific key
Expand All @@ -238,7 +240,8 @@ def to_dict(self):
"""Convert current cache to a dictionary"""
return dict(self.items())

def _sync_cache(self): ...
def _sync_cache(self):
...

def fetch(
self,
Expand Down

0 comments on commit 08cea4e

Please sign in to comment.