Skip to content

Commit

Permalink
Add working thrust implementation
Browse files (browse the repository at this point in the history)
  • Loading branch information
tom91136 committed Sep 2, 2023
1 parent 7bc1c0d commit bea3076
Showing 1 changed file with 45 additions and 41 deletions.
86 changes: 45 additions & 41 deletions src/thrust/fasten.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,27 +19,22 @@
template <size_t PPWI> class IMPL_CLS final : public Bude<PPWI> {

public:
static void fasten_main(const Params &p, std::vector<float> &results) {

thrust::device_vector<Atom> protein(p.protein);
thrust::device_vector<Atom> ligand(p.ligand);
thrust::device_vector<float> transforms_0(p.poses[0]);
thrust::device_vector<float> transforms_1(p.poses[1]);
thrust::device_vector<float> transforms_2(p.poses[2]);
thrust::device_vector<float> transforms_3(p.poses[3]);
thrust::device_vector<float> transforms_4(p.poses[4]);
thrust::device_vector<float> transforms_5(p.poses[5]);
thrust::device_vector<FFParams> forcefield(p.forcefield);
thrust::device_vector<float> energies(results.size());

thrust::device_vector<std::array<float, PPWI>> out(p.nposes() / PPWI);
static void fasten_main(const Params &p, //
thrust::device_vector<Atom> &protein, thrust::device_vector<Atom> &ligand,
thrust::device_vector<float> &transforms_0, thrust::device_vector<float> &transforms_1,
thrust::device_vector<float> &transforms_2, thrust::device_vector<float> &transforms_3,
thrust::device_vector<float> &transforms_4, thrust::device_vector<float> &transforms_5,
thrust::device_vector<FFParams> &forcefield, thrust::device_vector<float> &energies) {

thrust::counting_iterator<int> groups(0);
thrust::transform( //
groups, //
groups + (p.nposes() / PPWI), //
out.begin(), //
[=] __device__ __host__(const int group) {
thrust::for_each(
groups, groups + (p.nposes() / PPWI),
[natlig = p.natlig(), natpro = p.natpro(), //
protein = protein.data(), ligand = ligand.data(), //
transforms_0 = transforms_0.data(), transforms_1 = transforms_1.data(), transforms_2 = transforms_2.data(), //
transforms_3 = transforms_3.data(), transforms_4 = transforms_4.data(), transforms_5 = transforms_5.data(), //
forcefield = forcefield.data(), //
energies = energies.data()] __device__ __host__(const int group) {
std::array<std::array<Vec4<float>, 3>, PPWI> transform = {};
std::array<float, PPWI> etot = {};

Expand Down Expand Up @@ -69,8 +64,9 @@ template <size_t PPWI> class IMPL_CLS final : public Bude<PPWI> {
}

// Loop over ligand atoms
for (const Atom &l_atom : ligand) {
const FFParams l_params = forcefield[l_atom.type];
for (size_t il = 0; il < natlig; il++) {
const Atom &l_atom = ligand[il];
const FFParams &l_params = forcefield[l_atom.type];
const int lhphb_ltz = l_params.hphb < ZERO;
const int lhphb_gtz = l_params.hphb > ZERO;

Expand All @@ -87,9 +83,10 @@ template <size_t PPWI> class IMPL_CLS final : public Bude<PPWI> {
}

// Loop over protein atoms
for (const Atom &p_atom : protein) {
// // Load protein atom data
const FFParams p_params = forcefield[p_atom.type];
for (size_t ip = 0; ip < natpro; ip++) {
// Load protein atom data
const Atom &p_atom = protein[ip];
const FFParams &p_params = forcefield[p_atom.type];

const float radij = p_params.radius + l_params.radius;
const float r_radij = ONE / radij;
Expand Down Expand Up @@ -141,24 +138,11 @@ template <size_t PPWI> class IMPL_CLS final : public Bude<PPWI> {
}
}

////#pragma omp simd
// for (int l = 0; l < PPWI; l++) {
// etot[l] *= HALF;
// }
//
// return std::make_pair(group, etot);

// Write result
//#pragma omp simd
// for (int l = 0; l < PPWI; l++) {
// energies[group * PPWI + l] = etot[l] *= HALF;
// }

// Write result
#pragma omp simd
for (int l = 0; l < PPWI; l++) {
etot[l] *= HALF;
energies[group * PPWI + l] = etot[l] *= HALF;
}
return etot;
});
}

Expand Down Expand Up @@ -192,7 +176,11 @@ template <size_t PPWI> class IMPL_CLS final : public Bude<PPWI> {
checkError(IMPL_FN__(GetDeviceCount(&count)));
std::vector<Device> devices(count);
for (int i = 0; i < count; ++i) {
IMPL_FN__(DeviceProp) props{};
#if defined(__HIP_PLATFORM_HCC__) // can't use IMPL_TYPE__ here because of the extra _t suffix, thanks AMD
hipDeviceProp_t props{};
#else
cudaDeviceProp props{};
#endif
checkError(IMPL_FN__(GetDeviceProperties(&props, i)));
devices[i] = {i, std::string(props.name) + " (" + //
std::to_string(props.totalGlobalMem / 1024 / 1024) + "MB;" + //
Expand Down Expand Up @@ -227,13 +215,29 @@ template <size_t PPWI> class IMPL_CLS final : public Bude<PPWI> {

Sample sample(PPWI, wgsize, p.nposes());

thrust::device_vector<Atom> protein(p.protein);
thrust::device_vector<Atom> ligand(p.ligand);
thrust::device_vector<float> transforms_0(p.poses[0]);
thrust::device_vector<float> transforms_1(p.poses[1]);
thrust::device_vector<float> transforms_2(p.poses[2]);
thrust::device_vector<float> transforms_3(p.poses[3]);
thrust::device_vector<float> transforms_4(p.poses[4]);
thrust::device_vector<float> transforms_5(p.poses[5]);
thrust::device_vector<FFParams> forcefield(p.forcefield);
thrust::device_vector<float> energies(sample.energies.size());

for (size_t i = 0; i < p.totalIterations(); ++i) {
auto kernelStart = now();
fasten_main(p, sample.energies);
fasten_main(p, protein, ligand, //
transforms_0, transforms_1, transforms_2, //
transforms_3, transforms_4, transforms_5, //
forcefield, energies //
);
synchronise();
auto kernelEnd = now();
sample.kernelTimes.emplace_back(kernelStart, kernelEnd);
}
thrust::copy(energies.begin(), energies.end(), sample.energies.begin());

return sample;
};
Expand Down

0 comments on commit bea3076

Please sign in to comment.