Skip to content

Commit

Permalink
brillouin_zone_3d add keyword axes_vectors: dict[Literal["shaft", "co…
Browse files Browse the repository at this point in the history
…ne"], dict[str, Any]] | Literal[False] | None
  • Loading branch information
janosh committed Nov 29, 2024
1 parent 1f00492 commit 9a9a0a9
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions pymatviz/brillouin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def brillouin_zone_3d(
label_kwargs: dict[str, Any] | None = None,
# High symmetry path styling
path_kwargs: dict[str, Any] | Literal[False] | None = None,
# Coordinate axes styling
axes_vectors: dict[Literal["shaft", "cone"], dict[str, Any]]
| Literal[False]
| None = None,
) -> go.Figure:
"""Generate a 3D plotly figure of the first Brillouin zone for a given structure.
Expand All @@ -33,6 +37,10 @@ def brillouin_zone_3d(
# High symmetry path styling
path_kwargs (dict | Literal[False]): Styling for paths. Set to False to disable
plotting paths.
# Coordinate axes styling
axes_vectors (dict | False): Keywords for coordinate axes vectors. Split into
2 sub dicts axes_vectors={shaft: {...}, cone: {...}}. Use nested key
shaft.len to control vector length. Set to False to disable axes plotting.
Returns:
go.Figure: A plotly figure containing the first Brillouin zone
Expand Down Expand Up @@ -89,6 +97,7 @@ def brillouin_zone_3d(
colors = ["red", "green", "blue"]
labels = ["b₁", "b₂", "b₃"]

# Plot reciprocal lattice vectors
for idx, vec in enumerate(k_space_cell):
start, end = np.zeros(3), vec # Vector points

Expand Down Expand Up @@ -250,4 +259,63 @@ def brillouin_zone_3d(
camera=dict(eye=eye_position),
)

# Add coordinate axes if requested
if axes_vectors is not False:
default_cone_kwargs = dict(
sizemode="absolute",
sizeref=0.2, # Smaller arrow heads than reciprocal vectors
showscale=False,
opacity=0.7,
len=2,
)
cone_kwargs = default_cone_kwargs | (axes_vectors or {}).get("cone", {})
vec_len = cone_kwargs.pop("len", 2)
default_shaft_kwargs = dict(
mode="lines",
line=dict(color="gray", width=3),
showlegend=False,
hoverinfo="none",
)
shaft_kwargs = default_shaft_kwargs | (axes_vectors or {}).get("shaft", {})

# Coordinates vectors for x, y, z axes
for vector, label in (([1, 0, 0], "x"), ([0, 1, 0], "y"), ([0, 0, 1], "z")):
start, end = np.zeros(3), np.array(vector) * vec_len

# Add vector shaft
fig.add_scatter3d(
x=[start[0], end[0]],
y=[start[1], end[1]],
z=[start[2], end[2]],
**shaft_kwargs,
)

# Add arrow head
arrow_dir = 0.25 * (end - start)
fig.add_cone(
x=[end[0]],
y=[end[1]],
z=[end[2]],
u=[arrow_dir[0]],
v=[arrow_dir[1]],
w=[arrow_dir[2]],
anchor="cm",
hoverinfo="none",
colorscale=["gray", "gray"],
name=label,
**cone_kwargs,
)

# Add label
fig.add_scatter3d(
x=[end[0]],
y=[end[1]],
z=[end[2]],
mode="text",
text=[label],
textposition="top center",
textfont=dict(size=16, color="gray"),
showlegend=False,
)

return fig

0 comments on commit 9a9a0a9

Please sign in to comment.