Skip to content

Commit

Permalink
add: get_box_grid
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Steinmetzer committed Apr 10, 2024
1 parent b6e5928 commit f12295e
Showing 1 changed file with 105 additions and 6 deletions.
111 changes: 105 additions & 6 deletions pysisyphus/io/cube.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import functools
import itertools as it
from math import ceil
import tempfile
from typing import Tuple
Expand All @@ -16,10 +17,19 @@
from pysisyphus.wavefunction import Wavefunction


DEFAULT_MARGIN = 3.0


def get_mins_maxs(coords3d, margin=DEFAULT_MARGIN):
mins = coords3d.min(axis=0) - margin
maxs = coords3d.max(axis=0) + margin
return mins, maxs


@functools.singledispatch
def get_grid(coords3d: np.ndarray, num=10, offset=3.0):
minx, miny, minz = coords3d.min(axis=0) - offset
maxx, maxy, maxz = coords3d.max(axis=0) + offset
def get_grid(coords3d: np.ndarray, num=10, margin=DEFAULT_MARGIN):
(minx, miny, minz), (maxx, maxy, maxz) = get_mins_maxs(coords3d, margin)

X, Y, Z = np.mgrid[
minx : maxx : num * 1j,
miny : maxy : num * 1j,
Expand All @@ -35,9 +45,8 @@ def _(wf: Wavefunction, **kwargs):
return get_grid(wf.coords3d, **kwargs)


def get_grid_with_spacing(coords3d, spacing=0.30, margin=3.0):
minx, miny, minz = coords3d.min(axis=0) - margin
maxx, maxy, maxz = coords3d.max(axis=0) + margin
def get_grid_with_spacing(coords3d, spacing=0.30, margin=DEFAULT_MARGIN):
(minx, miny, minz), (maxx, maxy, maxz) = get_mins_maxs(coords3d, margin)
dx = maxx - minx
dy = maxy - miny
dz = maxz - minz
Expand All @@ -57,6 +66,96 @@ def get_grid_with_spacing(coords3d, spacing=0.30, margin=3.0):
return xyz, act_spacing, (nx, ny, nz)


@functools.singledispatch
def get_box_grid(coords3d: np.ndarray, num=100, margin=DEFAULT_MARGIN, edge_length=5):

(minx, miny, minz), (maxx, maxy, maxz) = get_mins_maxs(coords3d, margin)

try:
numx, numy, numz = num
# Cubic grid
except TypeError:
numx = numy = numz = num

try:
edge_lengthx, edge_lengthy, edge_lengthz = edge_length
# Same number of points along all dimensions
except TypeError:
edge_lengthx = edge_lengthy = edge_lengthz = edge_length

for n, e in zip((numx, numy, numz), (edge_lengthx, edge_lengthy, edge_lengthz)):
assert n % e == 0, f"'num={n}' must be a multiple of 'edge_length={e}'!"

# Grid extents
dx = maxx - minx
dy = maxy - miny
dz = maxz - minz
# Distance between points
spacing_x = dx / (numx - 1)
spacing_y = dy / (numy - 1)
spacing_z = dz / (numz - 1)
spacing = np.array((spacing_x, spacing_y, spacing_z))
trans_vec = np.array(
(
spacing_x * edge_lengthx,
spacing_y * edge_lengthy,
spacing_z * edge_lengthz,
)
)
ind_vec = np.array((edge_lengthx, edge_lengthy, edge_lengthz))
num_yz = numy * numz

def transform_inds(inds3d):
x, y, z = inds3d.T
return num_yz * x + numz * y + z

boxes_per_x = numx // edge_lengthx
boxes_per_y = numy // edge_lengthy
boxes_per_z = numz // edge_lengthz
box_size = edge_lengthx * edge_lengthy * edge_lengthz

# Build initial box
box3d = np.zeros((box_size, 3))
ind_box3d = np.zeros((box_size, 3), dtype=int)
xyz = np.zeros(3)
i = 0
for x in range(edge_lengthx):
xyz[0] = x * spacing_x
for y in range(edge_lengthy):
xyz[1] = y * spacing_y
for z in range(edge_lengthz):
xyz[2] = z * spacing_z
box3d[i] = xyz
ind_box3d[i] = (x, y, z)
i += 1
# Shift box to grid origin
box3d += np.array((minx, miny, minz))[None, :]

# Translate initial box along grid
grid3d = np.zeros((numx * numy * numz, 3))
# Also build & return an index array that sorts our box-grid into
# the usual grid order expected in cubes.
sort_inds = np.zeros(grid3d.shape[0], dtype=int)
xyz = np.zeros(3, dtype=int)
i = 0
for x in range(boxes_per_x):
xyz[0] = x
for y in range(boxes_per_y):
xyz[1] = y
for z in range(boxes_per_z):
xyz[2] = z
slc = slice(i * box_size, (i + 1) * box_size)
grid3d[slc] = box3d + (xyz * trans_vec)
sort_inds[slc] = transform_inds(ind_box3d + (xyz * ind_vec))
i += 1
return grid3d, spacing, (numx, numy, numz), sort_inds


@get_box_grid.register
def _(wf: Wavefunction, **kwargs):
return get_box_grid(wf.coords3d, **kwargs)


@dataclass
class Cube:
atoms: Tuple
Expand Down

0 comments on commit f12295e

Please sign in to comment.