Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types annotations for core.interface #3822

Merged
merged 26 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
579083b
tweak type and docstring
DanielYang59 May 8, 2024
853419d
use math.gcd over gcd
DanielYang59 May 12, 2024
07799f7
Merge branch master into add-types-core-interface
DanielYang59 Jun 1, 2024
4484d34
use more specific types for ClassVar
DanielYang59 Jun 3, 2024
ad27bc6
more: use more specific types for ClassVar
DanielYang59 Jun 3, 2024
7f464b0
fix some type errors and comment tweaks
DanielYang59 Jun 3, 2024
dabbe78
fix mypy errors
DanielYang59 Jun 3, 2024
f49eb45
enable types: more type errors to fix
DanielYang59 Jun 3, 2024
71f6cb0
fix type errors
DanielYang59 Jun 3, 2024
e4e6f40
fix type errors
DanielYang59 Jun 3, 2024
2a1ded5
fix mypy errors doesn't show locally
DanielYang59 Jun 3, 2024
ec4ec86
Merge branch 'master' into add-types-core-interface
DanielYang59 Jun 4, 2024
f3cabe8
revert change in test and convert rotation_axis and plane to tuple
DanielYang59 Jun 4, 2024
169551a
cast plane type to tuple
DanielYang59 Jun 4, 2024
76d4260
remove `del` of var name
DanielYang59 Jun 4, 2024
a10af26
add and update new type `Tuple3Ints = tuple[int, int, int]`
DanielYang59 Jun 4, 2024
33c1551
relocate `Tuple4Ints` to `core.interface`
DanielYang59 Jun 4, 2024
c09bc54
relocate `Tuple4Ints`
DanielYang59 Jun 4, 2024
d35cea2
use `Tuple3Floats`
DanielYang59 Jun 4, 2024
8573072
Merge branch 'master' into add-types-core-interface
DanielYang59 Jun 5, 2024
342d07c
Merge branch 'master' into add-types-core-interface
DanielYang59 Jun 6, 2024
369a9a9
revert usage of assert_allclose
DanielYang59 Jun 6, 2024
79e1284
use more meaningful types
DanielYang59 Jun 6, 2024
6b9589a
fix replacement
DanielYang59 Jun 6, 2024
8d8166b
Revert "fix replacement"
DanielYang59 Jun 6, 2024
64924f3
revert type aliases in docstring
DanielYang59 Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 26 additions & 31 deletions pymatgen/analysis/diffraction/tem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pymatgen.analysis.diffraction.core import AbstractDiffractionPatternCalculator
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.util.string import latexify_spacegroup, unicodeify_spacegroup
from pymatgen.util.typing import Tuple3Ints

if TYPE_CHECKING:
from numpy.typing import NDArray
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
self,
symprec: float | None = None,
voltage: float = 200,
beam_direction: tuple[int, int, int] = (0, 0, 1),
beam_direction: Tuple3Ints = (0, 0, 1),
camera_length: int = 160,
debye_waller_factors: dict[str, float] | None = None,
cs: float = 1,
Expand Down Expand Up @@ -104,9 +105,7 @@ def generate_points(coord_left: int = -10, coord_right: int = 10) -> np.ndarray:
points_matrix = (np.ravel(points[i]) for i in range(3))
return np.vstack(list(points_matrix)).transpose()

def zone_axis_filter(
self, points: list[tuple[int, int, int]] | np.ndarray, laue_zone: int = 0
) -> list[tuple[int, int, int]]:
def zone_axis_filter(self, points: list[Tuple3Ints] | np.ndarray, laue_zone: int = 0) -> list[Tuple3Ints]:
"""Filter out all points that exist within the specified Laue zone according to the zone axis rule.

Args:
Expand All @@ -122,11 +121,11 @@ def zone_axis_filter(
return []
filtered = np.where(np.dot(np.array(self.beam_direction), np.transpose(points)) == laue_zone)
result = points[filtered] # type: ignore
return cast(list[tuple[int, int, int]], [tuple(x) for x in result.tolist()])
return cast(list[Tuple3Ints], [tuple(x) for x in result.tolist()])

def get_interplanar_spacings(
self, structure: Structure, points: list[tuple[int, int, int]] | np.ndarray
) -> dict[tuple[int, int, int], float]:
self, structure: Structure, points: list[Tuple3Ints] | np.ndarray
) -> dict[Tuple3Ints, float]:
"""
Args:
structure (Structure): the input structure.
Expand All @@ -142,9 +141,7 @@ def get_interplanar_spacings(
interplanar_spacings_val = np.array([structure.lattice.d_hkl(x) for x in points_filtered])
return dict(zip(points_filtered, interplanar_spacings_val))

def bragg_angles(
self, interplanar_spacings: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], float]:
def bragg_angles(self, interplanar_spacings: dict[Tuple3Ints, float]) -> dict[Tuple3Ints, float]:
"""Get the Bragg angles for every hkl point passed in (where n = 1).

Args:
Expand All @@ -158,7 +155,7 @@ def bragg_angles(
bragg_angles_val = np.arcsin(self.wavelength_rel() / (2 * interplanar_spacings_val))
return dict(zip(plane, bragg_angles_val))

def get_s2(self, bragg_angles: dict[tuple[int, int, int], float]) -> dict[tuple[int, int, int], float]:
def get_s2(self, bragg_angles: dict[Tuple3Ints, float]) -> dict[Tuple3Ints, float]:
"""
Calculates the s squared parameter (= square of sin theta over lambda) for each hkl plane.

Expand All @@ -175,8 +172,8 @@ def get_s2(self, bragg_angles: dict[tuple[int, int, int], float]) -> dict[tuple[
return dict(zip(plane, s2_val))

def x_ray_factors(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[str, dict[tuple[int, int, int], float]]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[str, dict[Tuple3Ints, float]]:
"""
Calculates x-ray factors, which are required to calculate atomic scattering factors. Method partially inspired
by the equivalent process in the xrd module.
Expand Down Expand Up @@ -205,8 +202,8 @@ def x_ray_factors(
return x_ray_factors

def electron_scattering_factors(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[str, dict[tuple[int, int, int], float]]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[str, dict[Tuple3Ints, float]]:
"""
Calculates atomic scattering factors for electrons using the Mott-Bethe formula (1st order Born approximation).

Expand All @@ -232,8 +229,8 @@ def electron_scattering_factors(
return electron_scattering_factors

def cell_scattering_factors(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], int]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[Tuple3Ints, int]:
"""
Calculates the scattering factor for the whole cell.

Expand All @@ -258,9 +255,7 @@ def cell_scattering_factors(
scattering_factor_curr = 0
return cell_scattering_factors

def cell_intensity(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], float]:
def cell_intensity(self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]) -> dict[Tuple3Ints, float]:
"""
Calculates cell intensity for each hkl plane. For simplicity's sake, take I = |F|**2.

Expand Down Expand Up @@ -317,8 +312,8 @@ def get_pattern(
return pd.DataFrame(rows, columns=field_names)

def normalized_cell_intensity(
self, structure: Structure, bragg_angles: dict[tuple[int, int, int], float]
) -> dict[tuple[int, int, int], float]:
self, structure: Structure, bragg_angles: dict[Tuple3Ints, float]
) -> dict[Tuple3Ints, float]:
"""
Normalizes the cell_intensity dict to 1, for use in plotting.

Expand All @@ -340,8 +335,8 @@ def normalized_cell_intensity(
def is_parallel(
self,
structure: Structure,
plane: tuple[int, int, int],
other_plane: tuple[int, int, int],
plane: Tuple3Ints,
other_plane: Tuple3Ints,
) -> bool:
"""
Checks if two hkl planes are parallel in reciprocal space.
Expand All @@ -357,7 +352,7 @@ def is_parallel(
phi = self.get_interplanar_angle(structure, plane, other_plane)
return phi in (180, 0) or np.isnan(phi)

def get_first_point(self, structure: Structure, points: list) -> dict[tuple[int, int, int], float]:
def get_first_point(self, structure: Structure, points: list) -> dict[Tuple3Ints, float]:
"""Get the first point to be plotted in the 2D DP, corresponding to maximum d/minimum R.

Args:
Expand All @@ -378,7 +373,7 @@ def get_first_point(self, structure: Structure, points: list) -> dict[tuple[int,
return {max_d_plane: max_d}

@staticmethod
def get_interplanar_angle(structure: Structure, p1: tuple[int, int, int], p2: tuple[int, int, int]) -> float:
def get_interplanar_angle(structure: Structure, p1: Tuple3Ints, p2: Tuple3Ints) -> float:
"""Get the interplanar angle (in degrees) between the normal of two crystal planes.
Formulas from International Tables for Crystallography Volume C pp. 2-9.

Expand Down Expand Up @@ -432,9 +427,9 @@ def get_interplanar_angle(structure: Structure, p1: tuple[int, int, int], p2: tu

@staticmethod
def get_plot_coeffs(
p1: tuple[int, int, int],
p2: tuple[int, int, int],
p3: tuple[int, int, int],
p1: Tuple3Ints,
p2: Tuple3Ints,
p3: Tuple3Ints,
) -> np.ndarray:
"""
Calculates coefficients of the vector addition required to generate positions for each DP point
Expand All @@ -454,7 +449,7 @@ def get_plot_coeffs(
x = np.dot(a_pinv, b)
return np.ravel(x)

def get_positions(self, structure: Structure, points: list) -> dict[tuple[int, int, int], np.ndarray]:
def get_positions(self, structure: Structure, points: list) -> dict[Tuple3Ints, np.ndarray]:
"""
Calculates all the positions of each hkl point in the 2D diffraction pattern by vector addition.
Distance in centimeters.
Expand Down Expand Up @@ -524,7 +519,7 @@ def tem_dots(self, structure: Structure, points) -> list:

class dot(NamedTuple):
position: NDArray
hkl: tuple[int, int, int]
hkl: Tuple3Ints
intensity: float
film_radius: float
d_spacing: float
Expand Down
9 changes: 5 additions & 4 deletions pymatgen/analysis/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from pymatgen.analysis.local_env import NearNeighbors
from pymatgen.core import Species
from pymatgen.util.typing import Tuple3Ints


logger = logging.getLogger(__name__)
Expand All @@ -58,7 +59,7 @@

class ConnectedSite(NamedTuple):
site: PeriodicSite
jimage: tuple[int, int, int]
jimage: Tuple3Ints
index: Any # TODO: use more specific type
weight: float
dist: float
Expand Down Expand Up @@ -338,8 +339,8 @@ def add_edge(
self,
from_index: int,
to_index: int,
from_jimage: tuple[int, int, int] = (0, 0, 0),
to_jimage: tuple[int, int, int] | None = None,
from_jimage: Tuple3Ints = (0, 0, 0),
to_jimage: Tuple3Ints | None = None,
weight: float | None = None,
warn_duplicates: bool = True,
edge_properties: dict | None = None,
Expand Down Expand Up @@ -756,7 +757,7 @@ def map_indices(grp: Molecule) -> dict[int, int]:
warn_duplicates=False,
)

def get_connected_sites(self, n: int, jimage: tuple[int, int, int] = (0, 0, 0)) -> list[ConnectedSite]:
def get_connected_sites(self, n: int, jimage: Tuple3Ints = (0, 0, 0)) -> list[ConnectedSite]:
"""Get a named tuple of neighbors of site n:
periodic_site, jimage, index, weight.
Index is the index of the corresponding site
Expand Down
5 changes: 3 additions & 2 deletions pymatgen/analysis/interfaces/coherent_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections.abc import Iterator, Sequence

from pymatgen.core import Structure
from pymatgen.util.typing import Tuple3Ints


class CoherentInterfaceBuilder:
Expand All @@ -30,8 +31,8 @@ def __init__(
self,
substrate_structure: Structure,
film_structure: Structure,
film_miller: tuple[int, int, int],
substrate_miller: tuple[int, int, int],
film_miller: Tuple3Ints,
substrate_miller: Tuple3Ints,
zslgen: ZSLGenerator | None = None,
):
"""
Expand Down
5 changes: 3 additions & 2 deletions pymatgen/analysis/interfaces/substrate_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import Self

from pymatgen.core import Structure
from pymatgen.util.typing import Tuple3Ints


@dataclass
Expand All @@ -24,8 +25,8 @@ class SubstrateMatch(ZSLMatch):
energy if provided, and the elastic energy.
"""

film_miller: tuple[int, int, int]
substrate_miller: tuple[int, int, int]
film_miller: Tuple3Ints
substrate_miller: Tuple3Ints
strain: Strain
von_mises_strain: float
ground_state_energy: float
Expand Down
3 changes: 2 additions & 1 deletion pymatgen/analysis/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from typing_extensions import Self

from pymatgen.core.composition import SpeciesLike
from pymatgen.util.typing import Tuple3Ints


__author__ = "Shyue Ping Ong, Geoffroy Hautier, Sai Jayaraman, "
Expand Down Expand Up @@ -540,7 +541,7 @@ def _get_nn_shell_info(
return list(all_sites.values())

@staticmethod
def _get_image(structure: Structure, site: Site) -> tuple[int, int, int]:
def _get_image(structure: Structure, site: Site) -> Tuple3Ints:
"""Private convenience method for get_nn_info,
gives lattice image from provided PeriodicSite and Structure.

Expand Down
4 changes: 3 additions & 1 deletion pymatgen/analysis/surface_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
if TYPE_CHECKING:
from typing_extensions import Self

from pymatgen.util.typing import Tuple3Ints

EV_PER_ANG2_TO_JOULES_PER_M2 = 16.0217656

__author__ = "Richard Tran"
Expand Down Expand Up @@ -566,7 +568,7 @@ def area_frac_vs_chempot_plot(
all_chempots = np.linspace(min(chempot_range), max(chempot_range), increments)

# initialize a dictionary of lists of fractional areas for each hkl
hkl_area_dict: dict[tuple[int, int, int], list[float]] = {}
hkl_area_dict: dict[Tuple3Ints, list[float]] = {}
for hkl in self.all_slab_entries:
hkl_area_dict[hkl] = []

Expand Down
2 changes: 1 addition & 1 deletion pymatgen/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Composition(collections.abc.Hashable, collections.abc.Mapping, MSONable, S

# Special formula handling for peroxides and certain elements. This is so
# that formula output does not write LiO instead of Li2O2 for example.
special_formulas: ClassVar = dict(
special_formulas: ClassVar[dict[str, str]] = dict(
LiO="Li2O2",
NaO="Na2O2",
KO="K2O2",
Expand Down
Loading