Skip to content

Commit

Permalink
Merge pull request #596 from HEXRD/numba-findduplicate
Browse files Browse the repository at this point in the history
numba accelerated findDuplicateVectors
  • Loading branch information
saransh13 authored Jan 9, 2024
2 parents 9b4129d + bea9b1c commit bb30312
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 10 deletions.
24 changes: 24 additions & 0 deletions hexrd/material/crystallography.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,9 @@ def __init__(self,
raise RuntimeError('have unparsed keyword arguments with keys: '
+ str(list(kwargs.keys())))

# This is only used to calculate the structure factor if invalidated
self.__unitcell = None

self.__calc()

return
Expand Down Expand Up @@ -935,7 +938,27 @@ def set_wavelength(self, wavelength):

wavelength = property(get_wavelength, set_wavelength, None)

def invalidate_structure_factor(self, unitcell):
    """Discard the cached structure factor without recomputing it.

    Computing the structure factor can be expensive, so this only marks
    the cached structure factor and powder intensity as invalid (None)
    and remembers `unitcell`, from which they can be lazily recomputed
    the next time they are needed.
    """
    self.__unitcell = unitcell
    self.__structFact = None
    self._powder_intensity = None

def _compute_sf_if_needed(self):
    """Lazily recompute the structure factor when it has been invalidated.

    Nothing happens unless a unit cell has been stored and at least one
    of the cached structure factor / powder intensity is None.
    """
    if self.__structFact is not None and self._powder_intensity is not None:
        # Both cached values are still valid.
        return

    if self.__unitcell is None:
        # No unit cell available to compute from.
        return

    # Computing the structure factor can be expensive, which is why it is
    # deferred until a caller actually needs it.
    all_hkls = self.getHKLs(allHKLs=True)
    self.set_structFact(self.__unitcell.CalcXRSF(all_hkls))

def get_structFact(self):
    """Return the structure factor entries that are not excluded.

    The structure factor is lazily (re)computed first if it was
    invalidated.
    """
    self._compute_sf_if_needed()
    struct_fact = self.__structFact
    return struct_fact[~self.exclusions]

def set_structFact(self, structFact):
Expand All @@ -953,6 +976,7 @@ def set_structFact(self, structFact):

@property
def powder_intensity(self):
    """Powder intensity entries that are not excluded.

    The underlying structure factor is lazily (re)computed first if it
    was invalidated.
    """
    self._compute_sf_if_needed()
    intensity = self._powder_intensity
    return intensity[~self.exclusions]

@staticmethod
Expand Down
14 changes: 6 additions & 8 deletions hexrd/material/material.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(
self.reset_v0()

self._newPdata()
self.update_structure_factor()
self.invalidate_structure_factor()

def __str__(self):
"""String representation"""
Expand Down Expand Up @@ -291,7 +291,7 @@ def _newUnitcell(self):
def _hkls_changed(self):
# Call this when something happens that changes the hkls...
self._newPdata()
self.update_structure_factor()
self.invalidate_structure_factor()

def _newPdata(self):
"""Create a new plane data instance if the hkls have changed"""
Expand Down Expand Up @@ -405,10 +405,8 @@ def enable_hkls_below_tth(self, tth_threshold=90.0):

self._pData.exclusions = dflt_excl

def update_structure_factor(self):
    """Eagerly compute the structure factor and store it on the plane data.

    The structure factor is evaluated by the unit cell for every hkl of
    the plane data (allHKLs=True).
    """
    all_hkls = self.planeData.getHKLs(allHKLs=True)
    struct_fact = self.unitcell.CalcXRSF(all_hkls)
    self.planeData.set_structFact(struct_fact)
def invalidate_structure_factor(self):
    """Invalidate the plane data's structure factor.

    The current unit cell is handed to the plane data along with the
    invalidation request.
    """
    plane_data = self.planeData
    plane_data.invalidate_structure_factor(self.unitcell)

def compute_powder_overlay(
self, ttharray=np.linspace(0, 80, 2000), fwhm=0.25, scale=1.0
Expand Down Expand Up @@ -1268,7 +1266,7 @@ def charge(self, vals):

self._charge = vals
# self._newUnitcell()
# self.update_structure_factor()
# self.invalidate_structure_factor()

@property
def absorption_length(self):
Expand Down Expand Up @@ -1390,7 +1388,7 @@ def _set_atomdata(self, atomtype, atominfo, U, charge):
self.charge = charge

self._newUnitcell()
self.update_structure_factor()
self.invalidate_structure_factor()

#
# ========== Methods
Expand Down
91 changes: 89 additions & 2 deletions hexrd/matrixutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@

from scipy import sparse

from hexrd.utils.decorators import numba_njit_if_available
from hexrd import constants
from hexrd.constants import USE_NUMBA
if USE_NUMBA:
import numba

from numba import prange

# module variables
sqr6i = 1./np.sqrt(6.)
Expand Down Expand Up @@ -582,7 +583,7 @@ def uniqueVectors(v, tol=1.0e-12):
return vSrt[:, ivInd[0:nUniq]]


def findDuplicateVectors(vec, tol=vTol, equivPM=False):
def findDuplicateVectors_old(vec, tol=vTol, equivPM=False):
"""
Find vectors in an array that are equivalent to within
a specified tolerance
Expand Down Expand Up @@ -682,6 +683,92 @@ def findDuplicateVectors(vec, tol=vTol, equivPM=False):

return eqv, uid

def findDuplicateVectors(vec, tol=vTol, equivPM=False):
    """
    Find columns of `vec` that are equivalent to within a tolerance.

    This wraps `_findduplicatevectors` and converts its NaN-padded table
    into Python lists.

    INPUT:
        vec      array whose columns are the vectors to compare
        tol      scalar tolerance passed on to `_findduplicatevectors`
        equivPM  passed on to `_findduplicatevectors`; set to True if vec
                 and -vec are to be treated as equivalent

    OUTPUT:
        eqv  list of lists; each inner list starts with a reference column
             index followed by the indices of the columns found equivalent
             to it. Reference columns without duplicates are omitted.
        uid  list of the column indices that were not recorded as a
             duplicate of another column.
    """
    dup_table = _findduplicatevectors(vec, tol, equivPM)

    # Entries that are not NaN hold the index of a duplicate column.
    filled = ~np.isnan(dup_table)

    groups = []
    for ref, (row, row_filled) in enumerate(zip(dup_table, filled)):
        members = row[row_filled]
        if members.shape[0] > 0:
            groups.append([ref] + list(members.astype(np.int64)))

    # Every column that never shows up as a duplicate is kept as unique.
    all_cols = np.arange(0, vec.shape[1], dtype=np.int64)
    dup_cols = dup_table[filled].astype(np.int64)
    unique_cols = list(np.delete(all_cols, dup_cols))

    return groups, unique_cols


@numba_njit_if_available(cache=True, nogil=True)
def _findduplicatevectors(vec, tol, equivPM):
    """
    Find vectors in an array that are equivalent to within
    a specified tolerance. Code is accelerated by numba when available.

    USAGE:

        eqv = _findduplicatevectors(vec, tol, equivPM)

    INPUT:

        1) vec is n x m, an array of m horizontally concatenated
           n-dimensional vectors.
        2) tol is a scalar tolerance. Two columns are treated as
           equivalent when the sum of the absolute values of their
           component-wise differences is smaller than tol.
        3) set equivPM to True if vec and -vec
           are to be treated as equivalent

    OUTPUT:

        1) eqv is an m x m float64 array filled with NaN. Row ii holds,
           left-justified, the indices jj > ii of the columns found
           equivalent to column ii.

    NOTES:

        A column is recorded as a duplicate at most once. A column that
        was already recorded as a duplicate is nevertheless still used as
        a reference column ii for the columns that follow it.

        For example:

              | 1     2     2     2     1     2     7 |
        vec = |                                       |
              | 2     3     5     3     2     3     3 |

        eqv[0, :1] = [4]
        eqv[1, :2] = [3, 5]

        and every other entry of eqv is NaN.
    """
    m = vec.shape[1]

    eqv = np.empty((m, m), dtype=np.float64)
    eqv[:] = np.nan

    # claimed[jj] is True once column jj has been recorded as a duplicate.
    # A boolean mask gives an O(1) lookup, where scanning a growing list
    # of indices would cost O(m) for every (ii, jj) pair.
    claimed = np.zeros(m, dtype=np.bool_)

    for ii in range(m):
        ctr = 0
        for jj in range(ii + 1, m):
            if claimed[jj]:
                continue

            is_dup = np.sum(np.abs(vec[:, ii] - vec[:, jj])) < tol
            if equivPM and not is_dup:
                # vec[:, ii] - (-vec[:, jj]) == vec[:, ii] + vec[:, jj],
                # so no negated copy of vec is needed.
                is_dup = np.sum(np.abs(vec[:, ii] + vec[:, jj])) < tol

            if is_dup:
                eqv[ii, ctr] = jj
                claimed[jj] = True
                ctr += 1

    return eqv

def normvec(v):
mag = np.linalg.norm(v)
Expand Down
18 changes: 18 additions & 0 deletions hexrd/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import OrderedDict
from functools import wraps

import numba
import numpy as np
import xxhash

Expand Down Expand Up @@ -139,3 +140,20 @@ def decorator(func):
from numba import prange
else:
prange = range


# A decorator to limit the number of numba threads
def limit_numba_threads(max_threads):
    """Build a decorator that caps the numba thread count during a call.

    While the decorated function runs, the number of numba threads is set
    to the smaller of the current thread count and `max_threads`. The
    previous thread count is restored afterwards, even if the function
    raises.

    Parameters
    ----------
    max_threads : int
        Upper bound on the number of numba threads used during the call.

    Returns
    -------
    callable
        A decorator to apply to the function that should be limited.
    """
    def decorator(func):
        # Preserve the wrapped function's name, docstring and other
        # metadata so that introspection keeps working.
        @wraps(func)
        def wrapper(*args, **kwargs):
            prev_num_threads = numba.get_num_threads()
            # Never raise the thread count above what is currently set.
            new_num_threads = min(prev_num_threads, max_threads)
            numba.set_num_threads(new_num_threads)
            try:
                return func(*args, **kwargs)
            finally:
                numba.set_num_threads(prev_num_threads)

        return wrapper

    return decorator

0 comments on commit bb30312

Please sign in to comment.