Skip to content

Commit

Permalink
Merge branch 'master' of github.com:materialsvirtuallab/monty
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Dec 10, 2024
2 parents 94de03e + b62bfe7 commit 8ca861e
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from enum import Enum
from hashlib import sha1
from importlib import import_module
from inspect import getfullargspec
from inspect import getfullargspec, isclass
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
Expand Down Expand Up @@ -70,12 +70,12 @@ def _load_redirect(redirect_file) -> dict:
return dict(redirect_dict)


def _check_type(obj, type_str: tuple[str, ...] | str) -> bool:
def _check_type(obj: object, type_str: tuple[str, ...] | str) -> bool:
"""Alternative to isinstance that avoids imports.
Checks whether obj is an instance of the type defined by type_str. This
removes the need to explicitly import type_str. Handles subclasses like
isinstance does. E.g.::
isinstance does. E.g.:
class A:
pass
Expand All @@ -90,21 +90,22 @@ class B(A):
assert isinstance(b, A)
assert not isinstance(a, B)
type_str: str | tuple[str]
Note for future developers: the type_str is not always obvious for an
object. For example, pandas.DataFrame is actually pandas.core.frame.DataFrame.
object. For example, pandas.DataFrame is actually "pandas.core.frame.DataFrame".
To find out the type_str for an object, run type(obj).mro(). This will
list all the types that an object can resolve to in order of generality
(all objects have the builtins.object as the last one).
(all objects have the "builtins.object" as the last one).
"""
type_str = type_str if isinstance(type_str, tuple) else (type_str,)
# I believe this try-except is only necessary for callable types
try:
mro = type(obj).mro()
except TypeError:
# This function is intended as an alternative of "isinstance",
# therefore wouldn't check class
if isclass(obj):
return False
return any(f"{o.__module__}.{o.__name__}" == ts for o in mro for ts in type_str)

type_str = type_str if isinstance(type_str, tuple) else (type_str,)

mro = type(obj).mro()

return any(f"{o.__module__}.{o.__qualname__}" == ts for o in mro for ts in type_str)


class MSONable:
Expand Down
129 changes: 129 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MontyDecoder,
MontyEncoder,
MSONable,
_check_type,
_load_redirect,
jsanitize,
load,
Expand Down Expand Up @@ -1081,6 +1082,134 @@ def test_enum(self):
na2 = EnumAsDict.from_dict(d_)
assert na2 == na1


class TestCheckType:
def test_check_subclass(self):
class A:
pass

class B(A):
pass

a, b = A(), B()

class_name_A = f"{type(a).__module__}.{type(a).__qualname__}"
class_name_B = f"{type(b).__module__}.{type(b).__qualname__}"

# a is an instance of A, but not B
assert _check_type(a, class_name_A)
assert isinstance(a, A)
assert not _check_type(a, class_name_B)
assert not isinstance(a, B)

# b is an instance of both B and A
assert _check_type(b, class_name_B)
assert isinstance(b, B)
assert _check_type(b, class_name_A)
assert isinstance(b, A)

def test_check_class(self):
"""This should not work for classes."""

class A:
pass

class B(A):
pass

class_name_A = f"{A.__module__}.{A.__qualname__}"
class_name_B = f"{B.__module__}.{B.__qualname__}"

# Test class behavior (should return False, like isinstance does)
assert not _check_type(A, class_name_A)
assert not _check_type(B, class_name_B)
assert not _check_type(B, class_name_A)

def test_callable(self):
# Test function
def my_function():
pass

callable_class_name = (
f"{type(my_function).__module__}.{type(my_function).__qualname__}"
)

assert _check_type(my_function, callable_class_name), callable_class_name
assert isinstance(my_function, type(my_function))

# Test callable class
class MyCallableClass:
def __call__(self):
pass

callable_instance = MyCallableClass()
assert callable(callable_instance)

callable_class_instance_name = f"{type(callable_instance).__module__}.{type(callable_instance).__qualname__}"

assert _check_type(
callable_instance, callable_class_instance_name
), callable_class_instance_name
assert isinstance(callable_instance, MyCallableClass)

def test_numpy(self):
# Test NumPy array
arr = np.array([1, 2, 3])

assert _check_type(arr, "numpy.ndarray")
assert isinstance(arr, np.ndarray)

# Test NumPy generic
scalar = np.float64(3.14)

assert _check_type(scalar, "numpy.generic")
assert isinstance(scalar, np.generic)

@pytest.mark.skipif(pd is None, reason="pandas is not installed")
def test_pandas(self):
# Test pandas DataFrame
df = pd.DataFrame({"a": [1, 2, 3]})

assert _check_type(df, "pandas.core.frame.DataFrame")
assert isinstance(df, pd.DataFrame)

assert _check_type(df, "pandas.core.base.PandasObject")
assert isinstance(df, pd.core.base.PandasObject)

# Test pandas Series
series = pd.Series([1, 2, 3])

assert _check_type(series, "pandas.core.series.Series")
assert isinstance(series, pd.Series)

assert _check_type(series, "pandas.core.base.PandasObject")
assert isinstance(series, pd.core.base.PandasObject)

@pytest.mark.skipif(torch is None, reason="torch is not installed")
def test_torch(self):
tensor = torch.tensor([1, 2, 3])

assert _check_type(tensor, "torch.Tensor")
assert isinstance(tensor, torch.Tensor)

@pytest.mark.skipif(pydantic is None, reason="pydantic is not installed")
def test_pydantic(self):
class MyModel(pydantic.BaseModel):
name: str

model_instance = MyModel(name="Alice")

assert _check_type(model_instance, "pydantic.main.BaseModel")
assert isinstance(model_instance, pydantic.BaseModel)

@pytest.mark.skipif(pint is None, reason="pint is not installed")
def test_pint(self):
ureg = pint.UnitRegistry()
qty = 3 * ureg.meter

assert _check_type(qty, "pint.registry.Quantity")
assert isinstance(qty, pint.Quantity)

@pytest.mark.skipif(ObjectId is None, reason="bson not present")
def test_extended_json(self):
from bson import json_util
Expand Down

0 comments on commit 8ca861e

Please sign in to comment.