Annotate rdist; check for bad hellinger input
lmcinnes committed Nov 20, 2019
1 parent b17f0ab commit 738bc57
Showing 2 changed files with 14 additions and 3 deletions.
14 changes: 11 additions & 3 deletions umap/layouts.py
@@ -26,7 +26,13 @@ def clip(val):
         return val


-@numba.njit("f4(f4[:],f4[:])", fastmath=True)
+@numba.njit("f4(f4[::1],f4[::1])",
+            fastmath=True,
+            cache=True,

ekerazha Nov 25, 2019

Contributor

@lmcinnes
After this change I get the following error on Windows: RuntimeError: cannot cache function 'rdist': no locator available for file 'C:\\Users\\<user>\\Anaconda3\\lib\\site-packages\\umap_learn-0.4.0-py3.7.egg\\umap\\layouts.py'

lmcinnes Nov 26, 2019

Author Owner

Hmm, I'm not sure about the issues on Windows. You can probably just remove the "cache=True" from the annotation and everything will work. You may also want to report this upstream to the numba people.

ekerazha Dec 3, 2019

Contributor

@lmcinnes
This is the issue: numba/numba#4908
It happens when you install the package with python setup.py install (because it creates an egg). The workaround is to use pip install .

lmcinnes Dec 3, 2019

Author Owner

Ah, thank you. I'll try to remember to add something to the FAQ about that.
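To check whether an installation is the egg-based kind described in this thread, one option is to look for a `.egg` directory in the module's path. The following is only a rough diagnostic sketch: the helper name `installed_from_egg` is hypothetical, `egg_path` is taken from the error message above, and `pip_path` is an assumed example of a non-egg layout.

```python
import re


def installed_from_egg(module_file: str) -> bool:
    """Return True if any directory component of the path ends in '.egg'.

    Hypothetical helper, not part of umap or numba.
    """
    # Split on both Windows and POSIX separators (repeated ones collapse).
    parts = re.split(r"[\\/]+", module_file)
    # Ignore the last component, which is the file itself.
    return any(part.lower().endswith(".egg") for part in parts[:-1])


# Path from the reported error (user name elided as in the report).
egg_path = r"C:\Users\<user>\Anaconda3\lib\site-packages\umap_learn-0.4.0-py3.7.egg\umap\layouts.py"
# Assumed layout after `pip install .` (illustrative only).
pip_path = r"C:\Users\<user>\Anaconda3\lib\site-packages\umap\layouts.py"

print(installed_from_egg(egg_path))
print(installed_from_egg(pip_path))
```

In practice the argument would be something like `umap.layouts.__file__`; if the check is true, reinstalling with pip install . is the workaround suggested above.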

+            locals={"result": numba.types.float32,
+                    "diff": numba.types.float32,
+                    "dim": numba.types.int32},
+)
 def rdist(x, y):
     """Reduced Euclidean distance.
@@ -40,8 +46,10 @@ def rdist(x, y):
         The squared euclidean distance between x and y
     """
     result = 0.0
-    for i in range(x.shape[0]):
-        result += (x[i] - y[i]) ** 2
+    dim = x.shape[0]
+    for i in range(dim):
+        diff = x[i] - y[i]
+        result += diff * diff

     return result
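The loop rewrite in this hunk is intended to leave the result unchanged. A minimal plain-Python sketch of the two loop bodies, assuming ordinary Python floats and lists instead of Numba-compiled float32 arrays (the names `rdist_before` and `rdist_after` are hypothetical, for illustration only):

```python
def rdist_before(x, y):
    """Squared Euclidean distance, written as in the removed lines."""
    result = 0.0
    for i in range(len(x)):
        result += (x[i] - y[i]) ** 2
    return result


def rdist_after(x, y):
    """Squared Euclidean distance, written as in the added lines."""
    result = 0.0
    dim = len(x)  # stands in for x.shape[0]
    for i in range(dim):
        diff = x[i] - y[i]
        result += diff * diff  # explicit multiply instead of ** 2
    return result


x = [0.0, 0.0]
y = [3.0, 4.0]
print(rdist_before(x, y), rdist_after(x, y))
```

Naming `diff` and `dim` as locals is what lets the decorator's `locals` argument pin their types; this sketch does not exercise that part.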

3 changes: 3 additions & 0 deletions umap/umap_.py
@@ -1503,6 +1503,9 @@ def fit(self, X, y=None):

         self._validate_parameters()

+        if self.metric == "hellinger" and X.min() < 0:
+            raise ValueError("Metric 'hellinger' does not support negative values")
+
         if self.verbose:
             print(str(self))
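A standalone sketch of this kind of input check, assuming the data is a non-empty list of rows rather than an array with a `.min()` method. The function name `check_metric_input` is hypothetical. The metric name is compared with `==`, since `is` tests object identity and is not a reliable way to compare strings.

```python
def check_metric_input(metric, rows):
    """Raise ValueError if 'hellinger' is requested on data with negatives.

    Hypothetical helper; `rows` is assumed to be a non-empty list of
    non-empty lists of numbers.
    """
    if metric == "hellinger":
        smallest = min(min(row) for row in rows)
        if smallest < 0:
            raise ValueError(
                "Metric 'hellinger' does not support negative values"
            )


# Non-negative data passes silently.
check_metric_input("hellinger", [[0.0, 0.5], [0.25, 0.25]])

# Negative data is rejected.
try:
    check_metric_input("hellinger", [[0.0, -0.5]])
except ValueError as err:
    print(err)
```

Other metrics are left alone in this sketch, so negative values pass through unchanged for, say, "euclidean".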

