From 9a9a0a92a2abaebffe51ddfb9ce81b8c0adfc48e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 29 Nov 2024 15:37:47 -0500 Subject: [PATCH] brillouin_zone_3d add keyword axes_vectors: dict[Literal["shaft", "cone"], dict[str, Any]] | Literal[False] | None --- pymatviz/brillouin.py | 68 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/pymatviz/brillouin.py b/pymatviz/brillouin.py index efdc2cc7..cf072b83 100644 --- a/pymatviz/brillouin.py +++ b/pymatviz/brillouin.py @@ -19,6 +19,10 @@ def brillouin_zone_3d( label_kwargs: dict[str, Any] | None = None, # High symmetry path styling path_kwargs: dict[str, Any] | Literal[False] | None = None, + # Coordinate axes styling + axes_vectors: dict[Literal["shaft", "cone"], dict[str, Any]] + | Literal[False] + | None = None, ) -> go.Figure: """Generate a 3D plotly figure of the first Brillouin zone for a given structure. @@ -33,6 +37,10 @@ def brillouin_zone_3d( # High symmetry path styling path_kwargs (dict | Literal[False]): Styling for paths. Set to False to disable plotting paths. + # Coordinate axes styling + axes_vectors (dict | False): Keywords for coordinate axes vectors. Split into + 2 sub dicts axes_vectors={shaft: {...}, cone: {...}}. Use nested key + shaft.len to control vector length. Set to False to disable axes plotting. Returns: go.Figure: A plotly figure containing the first Brillouin zone @@ -89,6 +97,7 @@ def brillouin_zone_3d( colors = ["red", "green", "blue"] labels = ["b₁", "b₂", "b₃"] + # Plot reciprocal lattice vectors for idx, vec in enumerate(k_space_cell): start, end = np.zeros(3), vec # Vector points @@ -250,4 +259,63 @@ def brillouin_zone_3d( camera=dict(eye=eye_position), ) + # Add coordinate axes if requested + if axes_vectors is not False: + default_cone_kwargs = dict( + sizemode="absolute", + sizeref=0.2, # Smaller arrow heads than reciprocal vectors + showscale=False, + opacity=0.7, + len=2, + ) + cone_kwargs = default_cone_kwargs | (axes_vectors or {}).get("cone", {}) + vec_len = cone_kwargs.pop("len", 2) + default_shaft_kwargs = dict( + mode="lines", + line=dict(color="gray", width=3), + showlegend=False, + hoverinfo="none", + ) + shaft_kwargs = default_shaft_kwargs | (axes_vectors or {}).get("shaft", {}) + + # Coordinates vectors for x, y, z axes + for vector, label in (([1, 0, 0], "x"), ([0, 1, 0], "y"), ([0, 0, 1], "z")): + start, end = np.zeros(3), np.array(vector) * vec_len + + # Add vector shaft + fig.add_scatter3d( + x=[start[0], end[0]], + y=[start[1], end[1]], + z=[start[2], end[2]], + **shaft_kwargs, + ) + + # Add arrow head + arrow_dir = 0.25 * (end - start) + fig.add_cone( + x=[end[0]], + y=[end[1]], + z=[end[2]], + u=[arrow_dir[0]], + v=[arrow_dir[1]], + w=[arrow_dir[2]], + anchor="cm", + hoverinfo="none", + colorscale=["gray", "gray"], + name=label, + **cone_kwargs, + ) + + # Add label + fig.add_scatter3d( + x=[end[0]], + y=[end[1]], + z=[end[2]], + mode="text", + text=[label], + textposition="top center", + textfont=dict(size=16, color="gray"), + showlegend=False, + ) + return fig