Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

numba accelerated findDuplicateVectors #596

Merged
merged 12 commits into from
Jan 9, 2024
24 changes: 24 additions & 0 deletions hexrd/material/crystallography.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,9 @@ def __init__(self,
raise RuntimeError('have unparsed keyword arguments with keys: '
+ str(list(kwargs.keys())))

# This is only used to calculate the structure factor if invalidated
self.__unitcell = None

self.__calc()

return
Expand Down Expand Up @@ -935,7 +938,27 @@ def set_wavelength(self, wavelength):

wavelength = property(get_wavelength, set_wavelength, None)

def invalidate_structure_factor(self, unitcell):
# It can be expensive to compute the structure factor, so provide the
# option to just invalidate it, while providing a unit cell, so that
# it can be lazily computed from the unit cell.
self.__structFact = None
self._powder_intensity = None
self.__unitcell = unitcell

def _compute_sf_if_needed(self):
any_invalid = (
self.__structFact is None or
self._powder_intensity is None
)
if any_invalid and self.__unitcell is not None:
# Compute the structure factor first.
# This can be expensive to do, so we lazily compute it when needed.
hkls = self.getHKLs(allHKLs=True)
self.set_structFact(self.__unitcell.CalcXRSF(hkls))

def get_structFact(self):
self._compute_sf_if_needed()
return self.__structFact[~self.exclusions]

def set_structFact(self, structFact):
Expand All @@ -953,6 +976,7 @@ def set_structFact(self, structFact):

@property
def powder_intensity(self):
self._compute_sf_if_needed()
return self._powder_intensity[~self.exclusions]

@staticmethod
Expand Down
14 changes: 6 additions & 8 deletions hexrd/material/material.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(
self.reset_v0()

self._newPdata()
self.update_structure_factor()
self.invalidate_structure_factor()

def __str__(self):
"""String representation"""
Expand Down Expand Up @@ -291,7 +291,7 @@ def _newUnitcell(self):
def _hkls_changed(self):
# Call this when something happens that changes the hkls...
self._newPdata()
self.update_structure_factor()
self.invalidate_structure_factor()

def _newPdata(self):
"""Create a new plane data instance if the hkls have changed"""
Expand Down Expand Up @@ -405,10 +405,8 @@ def enable_hkls_below_tth(self, tth_threshold=90.0):

self._pData.exclusions = dflt_excl

def update_structure_factor(self):
hkls = self.planeData.getHKLs(allHKLs=True)
sf = self.unitcell.CalcXRSF(hkls)
self.planeData.set_structFact(sf)
def invalidate_structure_factor(self):
self.planeData.invalidate_structure_factor(self.unitcell)

def compute_powder_overlay(
self, ttharray=np.linspace(0, 80, 2000), fwhm=0.25, scale=1.0
Expand Down Expand Up @@ -1268,7 +1266,7 @@ def charge(self, vals):

self._charge = vals
# self._newUnitcell()
# self.update_structure_factor()
# self.invalidate_structure_factor()

@property
def absorption_length(self):
Expand Down Expand Up @@ -1390,7 +1388,7 @@ def _set_atomdata(self, atomtype, atominfo, U, charge):
self.charge = charge

self._newUnitcell()
self.update_structure_factor()
self.invalidate_structure_factor()

#
# ========== Methods
Expand Down
91 changes: 89 additions & 2 deletions hexrd/matrixutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@

from scipy import sparse

from hexrd.utils.decorators import numba_njit_if_available
from hexrd import constants
from hexrd.constants import USE_NUMBA
if USE_NUMBA:
import numba

from numba import prange

# module variables
sqr6i = 1./np.sqrt(6.)
Expand Down Expand Up @@ -582,7 +583,7 @@ def uniqueVectors(v, tol=1.0e-12):
return vSrt[:, ivInd[0:nUniq]]


def findDuplicateVectors(vec, tol=vTol, equivPM=False):
def findDuplicateVectors_old(vec, tol=vTol, equivPM=False):
"""
Find vectors in an array that are equivalent to within
a specified tolerance
Expand Down Expand Up @@ -682,6 +683,92 @@ def findDuplicateVectors(vec, tol=vTol, equivPM=False):

return eqv, uid

def findDuplicateVectors(vec, tol=vTol, equivPM=False):
eqv = _findduplicatevectors(vec, tol, equivPM)
uid = np.arange(0, vec.shape[1], dtype=np.int64)
mask = ~np.isnan(eqv)
idx = eqv[mask].astype(np.int64)
uid2 = list(np.delete(uid, idx))
eqv2 = []
for ii in range(eqv.shape[0]):
v = eqv[ii, mask[ii, :]]
if v.shape[0] > 0:
eqv2.append([ii] + list(v.astype(np.int64)))
return eqv2, uid2


@numba_njit_if_available(cache=True, nogil=True)
def _findduplicatevectors(vec, tol, equivPM):
"""
Find vectors in an array that are equivalent to within
a specified tolerance. code is accelerated by numba

USAGE:

eqv = DuplicateVectors(vec, *tol)

INPUT:

1) vec is n x m, a double array of m horizontally concatenated
n-dimensional vectors.
*2) tol is 1 x 1, a scalar tolerance. If not specified, the default
tolerance is 1e-14.
*3) set equivPM to True if vec and -vec
are to be treated as equivalent

OUTPUT:

1) eqv is 1 x p, a list of p equivalence relationships.

NOTES:

Each equivalence relationship is a 1 x q vector of indices that
represent the locations of duplicate columns/entries in the array
vec. For example:

| 1 2 2 2 1 2 7 |
vec = | |
| 2 3 5 3 2 3 3 |

eqv = [[1x2 double] [1x3 double]], where

eqv[0] = [0 4]
eqv[1] = [1 3 5]
"""

if equivPM:
vec2 = -vec.copy()

n = vec.shape[0]
m = vec.shape[1]

eqv = np.zeros((m, m), dtype=np.float64)
eqv[:] = np.nan
eqv_elem_master = []

for ii in range(m):
ctr = 0
eqv_elem = np.zeros((m, ), dtype=np.int64)
for jj in range(ii+1, m):
if not jj in eqv_elem_master:
if equivPM:
diff = np.sum(np.abs(vec[:, ii]-vec2[:, jj]))
diff2 = np.sum(np.abs(vec[:, ii]-vec[:, jj]))
if diff < tol or diff2 < tol:
eqv_elem[ctr] = jj
eqv_elem_master.append(jj)
ctr += 1
else:
diff = np.sum(np.abs(vec[:, ii]-vec[:, jj]))
if diff < tol:
eqv_elem[ctr] = jj
eqv_elem_master.append(jj)
ctr += 1

for kk in range(ctr):
eqv[ii, kk] = eqv_elem[kk]

return eqv

def normvec(v):
mag = np.linalg.norm(v)
Expand Down
18 changes: 18 additions & 0 deletions hexrd/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import OrderedDict
from functools import wraps

import numba
import numpy as np
import xxhash

Expand Down Expand Up @@ -139,3 +140,20 @@ def decorator(func):
from numba import prange
else:
prange = range


# A decorator to limit the number of numba threads
def limit_numba_threads(max_threads):
def decorator(func):
def wrapper(*args, **kwargs):
prev_num_threads = numba.get_num_threads()
new_num_threads = min(prev_num_threads, max_threads)
numba.set_num_threads(new_num_threads)
try:
return func(*args, **kwargs)
finally:
numba.set_num_threads(prev_num_threads)

return wrapper

return decorator