Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/1448 Refactor random module #1508

Merged
18 commits merged on
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 4 additions & 29 deletions heat/cluster/_kcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,34 +109,9 @@ def _initialize_cluster_centers(self, x: DNDarray):

# initialize the centroids by randomly picking some of the points
if self.init == "random":
# Samples will be drawn equally distributed from all involved processes
_, displ, _ = x.comm.counts_displs_shape(shape=x.shape, axis=0)
centroids = ht.empty(
(self.n_clusters, x.shape[1]), split=None, device=x.device, comm=x.comm
)
if x.split is None or x.split == 0:
for i in range(self.n_clusters):
samplerange = (
x.gshape[0] // self.n_clusters * i,
x.gshape[0] // self.n_clusters * (i + 1),
)
sample = ht.random.randint(samplerange[0], samplerange[1]).item()
proc = 0
for p in range(x.comm.size):
if displ[p] > sample:
break
proc = p
xi = ht.zeros(x.shape[1], dtype=x.dtype)
if x.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
xi.comm.Bcast(xi, root=proc)
centroids[i, :] = xi

else:
raise NotImplementedError("Not implemented for other splitting-axes")

self._cluster_centers = centroids
idx = ht.random.randint(0, x.shape[0] - 1, size=(self.n_clusters,), split=None)
centroids = x[idx, :]
self._cluster_centers = centroids if x.split == 1 else centroids.resplit_(None)

# directly passed centroids
elif isinstance(self.init, DNDarray):
Expand Down Expand Up @@ -172,7 +147,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
D2 = distances.min(axis=1)
D2.resplit_(axis=None)
prob = D2 / D2.sum()
random_position = ht.random.rand().item()
random_position = ht.random.rand()
sample = 0
sum = 0
for j in range(len(prob)):
Expand Down
1 change: 0 additions & 1 deletion heat/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def fit(self, x: DNDarray) -> self:
# initialize the clustering
self._initialize_cluster_centers(x)
self._n_iter = 0
matching_centroids = ht.zeros((x.shape[0]), split=x.split, device=x.device, comm=x.comm)

# iteratively fit the points to the centroids
for epoch in range(self.max_iter):
Expand Down
2 changes: 1 addition & 1 deletion heat/cluster/kmedians.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def fit(self, x: DNDarray):
# initialize the clustering
self._initialize_cluster_centers(x)
self._n_iter = 0
matching_centroids = ht.zeros((x.shape[0]), split=x.split, device=x.device, comm=x.comm)

# iteratively fit the points to the centroids
for epoch in range(self.max_iter):
# increment the iteration count
Expand Down
2 changes: 1 addition & 1 deletion heat/cluster/kmedoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def fit(self, x: DNDarray):
# initialize the clustering
self._initialize_cluster_centers(x)
self._n_iter = 0
matching_centroids = ht.zeros((x.shape[0]), split=x.split, device=x.device, comm=x.comm)

# iteratively fit the points to the centroids
for epoch in range(self.max_iter):
# increment the iteration count
Expand Down
4 changes: 2 additions & 2 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,13 @@ def test_inv(self):
ainv = ht.linalg.inv(a)
i = ht.eye(a.shape, split=1, dtype=a.dtype)
# loss of precision in distributed floating-point ops
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-12))
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-10))

ht.random.seed(42)
a = ht.random.random((20, 20), dtype=ht.float64, split=0)
ainv = ht.linalg.inv(a)
i = ht.eye(a.shape, split=0, dtype=a.dtype)
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-12))
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-10))

with self.assertRaises(RuntimeError):
ht.linalg.inv(ht.array([1, 2, 3], split=0))
Expand Down
4 changes: 2 additions & 2 deletions heat/core/linalg/tests/test_svdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_hsvd_rank_part1(self):
)
self.assertTrue(V_orth_err <= dtype_tol)
true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
self.assertTrue(true_rel_err <= err_est)
self.assertTrue(true_rel_err <= err_est or true_rel_err < dtype_tol)
else:
self.assertEqual(hsvd_rk, 1)
self.assertEqual(ht.norm(U), 0)
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_hsvd_rank_part1(self):
)
self.assertTrue(V_orth_err <= dtype_tol)
true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
self.assertTrue(true_rel_err <= err_est)
self.assertTrue(true_rel_err <= err_est or true_rel_err < dtype_tol)
self.assertTrue(true_rel_err <= tol)
else:
self.assertEqual(hsvd_rk, 1)
Expand Down
Loading
Loading