Skip to content

Commit

Permalink
Fix prob for kmeans||.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 20, 2018
1 parent e45ee04 commit 6a8669a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 28 deletions.
49 changes: 22 additions & 27 deletions src/gpu/kmeans/kmeans_init.cu
Original file line number Diff line number Diff line change
Expand Up @@ -235,29 +235,17 @@ struct PairWiseDistanceOp {

kernel::construct_distance_pairs_kernel<<<
dim3(GpuInfo::ins().blocks(32), div_roundup(_centroids.rows(), 16)),
dim3(16, 16)>>>(
dim3(16, 16)>>>( // FIXME: Tune this.
distance_pairs_.k_param(),
data_dot_.k_param(),
centroids_dot_.k_param());
CUDA_CHECK(cudaGetLastError());
std::cout << std::endl;
std::cout << "in distance op" << std::endl;
std::cout << distance_pairs_ << std::endl;
cublasHandle_t handle = GpuInfo::ins().cublas_handle();
T alpha = -2.0;
T beta = 1.0;
std::cout << "data.shape: " << _data.rows() << ", " << _data.cols() <<
"\tcentroids.shape: " << _centroids.rows() << ", " << _centroids.cols() <<
"\tdp.shape: " << distance_pairs_.rows() << ", " << distance_pairs_.cols() <<
std::endl;
std::cout << _data << std::endl;
std::cout << _centroids << std::endl;
std::cout << _centroids.dev_ptr() << std::endl;
Blas::gemm(
handle,
Expand All @@ -270,12 +258,11 @@ struct PairWiseDistanceOp {
&beta,
distance_pairs_.dev_ptr(), distance_pairs_.rows());
std::cout << distance_pairs_ << std::endl;
std::cout << "return" << std::endl;
return distance_pairs_;
}
};
template <typename T>
KmMatrix<T> KmeansLlInit<T>::probability(
KmMatrix<T>& _data, KmMatrix<T>& _centroids) {
Expand All @@ -301,13 +288,20 @@ KmMatrix<T> KmeansLlInit<T>::probability(
CUDA_CHECK(cudaGetLastError());
std::cout << min_distances << std::endl;
T cost = SumOp<T>().sum(min_distances);
std::cout << "cost: " << cost << std::endl;
// Re-use min_distances to store prob
MulOp<T> mul_op;
mul_op.mul(min_distances, min_distances, 1 / cost * over_sample_ * k_);
return min_distances;
KmMatrix<T> prob (min_distances.rows(), 1);
mul_op.mul(prob, min_distances, (over_sample_ * k_ * 1) / cost);
std::cout << prob << std::endl;
return prob;
}
Expand Down Expand Up @@ -357,19 +351,19 @@ KmMatrix<T> KmeansLlInit<T>::sample_centroids(KmMatrix<T>& _data, KmMatrix<T>& _
T prob_x = prob_ptr[idx];
return prob_x > thresh;
});
std::cout << std::endl;
return new_centroids;
}
template <typename T>
KmMatrix<T>
KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t _k) {
if (seed_ < 0) {
std::random_device rd;
seed_ = rd();
}
k_ = k;
k_ = _k;
std::mt19937 generator(0);
Expand All @@ -386,14 +380,15 @@ KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
KmMatrix<T> prob = probability(_data, centroids);
T cost = SumOp<T>().sum(prob);
// FIXME
// for (size_t i = 0; i < std::log(cost); ++i) {
for (size_t i = 0; i < 1; ++i) {
std::cout << "looping" << std::endl;
KmMatrix<T> new_centroids = sample_centroids(_data, centroids);
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
for (size_t i = 0; i < std::log(cost); ++i) {
prob = probability(_data, centroids);
KmMatrix<T> new_centroids = sample_centroids(_data, prob);
new_centroids.set_name ("new centroids");
std::cout << new_centroids << std::endl;
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
centroids.set_name ("centroids");
std::cout << centroids << std::endl;
}
if (centroids.rows() < k_) {
Expand All @@ -407,7 +402,7 @@ KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
#define INSTANTIATE(T) \
template KmMatrix<T> KmeansLlInit<T>::operator()( \
KmMatrix<T>& data, size_t k); \
KmMatrix<T>& _data, size_t _k); \
template KmMatrix<T> KmeansLlInit<T>::probability(KmMatrix<T>& data, \
KmMatrix<T>& centroids); \
template KmMatrix<T> KmeansLlInit<T>::sample_centroids( \
Expand Down
3 changes: 2 additions & 1 deletion src/gpu/kmeans/kmeans_init.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ struct KmeansLlInit : public KmeansInitBase<T> {
KmMatrix<T> probability(KmMatrix<T>& data, KmMatrix<T>& centroids);

public:
KmeansLlInit () : over_sample_ (2.0), seed_ (0), k_(0) {
KmeansLlInit (T _over_sample=2.0) :
over_sample_ (_over_sample), seed_ (0), k_(0) {
data_dot_.set_name ("data_dot");
distance_pairs_.set_name ("distance pairs");
}
Expand Down

0 comments on commit 6a8669a

Please sign in to comment.