Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

numba accelerated findDuplicateVectors #596

Merged
merged 12 commits into from
Jan 9, 2024
91 changes: 89 additions & 2 deletions hexrd/matrixutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@

from scipy import sparse

from hexrd.utils.decorators import limit_numba_threads, numba_njit_if_available
from hexrd import constants
from hexrd.constants import USE_NUMBA
if USE_NUMBA:
import numba

from numba import prange

# module variables
sqr6i = 1./np.sqrt(6.)
Expand Down Expand Up @@ -582,7 +583,7 @@ def uniqueVectors(v, tol=1.0e-12):
return vSrt[:, ivInd[0:nUniq]]


def findDuplicateVectors(vec, tol=vTol, equivPM=False):
def findDuplicateVectors_old(vec, tol=vTol, equivPM=False):
"""
Find vectors in an array that are equivalent to within
a specified tolerance
Expand Down Expand Up @@ -682,6 +683,92 @@ def findDuplicateVectors(vec, tol=vTol, equivPM=False):

return eqv, uid

def findDuplicateVectors(vec, tol=vTol, equivPM=False):
    """
    Find vectors in an array that are equivalent to within
    a specified tolerance.

    USAGE:

        eqv, uid = findDuplicateVectors(vec, tol=vTol, equivPM=False)

    INPUT:

        1) vec is n x m, an array of m horizontally concatenated
           n-dimensional vectors.
       *2) tol is a scalar tolerance on the summed absolute difference
           between two columns.
       *3) set equivPM to True if vec and -vec
           are to be treated as equivalent

    OUTPUT:

        1) eqv is a list of equivalence relationships.  Each entry is a
           list of column indices, in ascending order, whose first
           element is the lowest-index member of the group.  Columns
           without any duplicate do not appear in eqv.
        2) uid is a list with one representative column index per
           distinct vector (the first element of each group, plus every
           column that has no duplicate), in ascending order.

    NOTES:

        Groups are formed greedily from the lowest column index upward:
        once a column has been assigned to a group it neither starts a
        group of its own nor joins a later one.  Every column therefore
        belongs to exactly one group, and no group is a subset of
        another (e.g. three identical columns 1, 3, 5 yield [1, 3, 5]
        only, not an additional [3, 5]).
    """
    # Row ii of the kernel output holds the indices jj > ii that match
    # column ii, left-packed and padded with NaN.
    eqv = _findduplicatevectors(vec, tol, equivPM)

    m = vec.shape[1]
    assigned = np.zeros(m, dtype=bool)

    eqv2 = []
    uid2 = []
    for ii in range(m):
        if assigned[ii]:
            # Already a member of an earlier group.
            continue
        assigned[ii] = True
        uid2.append(ii)

        row = eqv[ii]
        matches = row[~np.isnan(row)].astype(np.int64)
        # Drop matches that were already claimed by an earlier group.
        matches = matches[~assigned[matches]]
        if matches.size > 0:
            assigned[matches] = True
            eqv2.append([ii] + matches.tolist())

    return eqv2, uid2


# We found that too many threads causes allocator contention,
# so limit the number of threads here to just 8.
@limit_numba_threads(8)
@numba_njit_if_available(cache=True, nogil=True, parallel=True)
def _findduplicatevectors(vec, tol, equivPM):
    """
    Find vectors in an array that are equivalent to within
    a specified tolerance. code is accelerated by numba

    USAGE:

        eqv = _findduplicatevectors(vec, tol, equivPM)

    INPUT:

        1) vec is n x m, a double array of m horizontally concatenated
           n-dimensional vectors.
        2) tol is a scalar tolerance.  Two columns match when the sum
           of the absolute values of their difference is strictly less
           than tol.
        3) set equivPM to True if vec and -vec
           are to be treated as equivalent

    OUTPUT:

        1) eqv is m x m, a float64 array.  Row ii holds, left-packed in
           ascending order, the indices jj > ii of the columns that
           match column ii; the remaining entries of the row are NaN.

    NOTES:

        Every row is filled independently of the others, so a column
        that matches an earlier column still gets its own row of
        matches.  For example:

              | 1     2     2     2     1     2     7 |
        vec = |                                       |
              | 2     3     5     3     2     3     3 |

        eqv[0, :1] = [4]
        eqv[1, :2] = [3 5]
        eqv[3, :1] = [5]

        and every other entry is NaN.
    """
    m = vec.shape[1]

    eqv = np.zeros((m, m), dtype=np.float64)
    eqv[:] = np.nan

    for ii in prange(m):
        # Each outer iteration writes only to its own row, so results
        # go straight into eqv without a per-row scratch buffer.
        ctr = 0

        # Plain range: numba serialises a prange nested inside another
        # prange anyway, and the ctr bookkeeping needs in-order
        # execution.
        for jj in range(ii + 1, m):
            is_dup = np.sum(np.abs(vec[:, ii] - vec[:, jj])) < tol
            if equivPM and not is_dup:
                # a - (-b) == a + b, so no negated copy of vec is
                # needed to test the sign-flipped match.
                is_dup = np.sum(np.abs(vec[:, ii] + vec[:, jj])) < tol

            if is_dup:
                eqv[ii, ctr] = jj
                ctr += 1

    return eqv

def normvec(v):
mag = np.linalg.norm(v)
Expand Down
18 changes: 18 additions & 0 deletions hexrd/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import OrderedDict
from functools import wraps

import numba
import numpy as np
import xxhash

Expand Down Expand Up @@ -139,3 +140,20 @@ def decorator(func):
from numba import prange
else:
prange = range


# A decorator to limit the number of numba threads
def limit_numba_threads(max_threads):
    """
    Build a decorator that caps the numba thread count during a call.

    Parameters
    ----------
    max_threads : int
        Upper bound on the number of numba threads while the decorated
        function runs.  If the current setting is already lower, it is
        left as is.

    Returns
    -------
    callable
        A decorator.  The function it returns forwards all positional
        and keyword arguments to the decorated function and returns its
        result; the previous numba thread count is restored afterwards,
        even if the call raises.
    """
    def decorator(func):
        # Preserve the wrapped function's name, docstring and
        # __wrapped__ so introspection and help() still work.
        @wraps(func)
        def wrapper(*args, **kwargs):
            prev_num_threads = numba.get_num_threads()
            new_num_threads = min(prev_num_threads, max_threads)
            numba.set_num_threads(new_num_threads)
            try:
                return func(*args, **kwargs)
            finally:
                numba.set_num_threads(prev_num_threads)

        return wrapper

    return decorator