Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding virial support #10

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
93 changes: 59 additions & 34 deletions dlext/include/DLExt.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
#ifndef LAMMPS_DLPACK_EXTENSION_H_
#define LAMMPS_DLPACK_EXTENSION_H_

#include "LAMMPSView.h"
#include "atom.h"
#include "fix.h"

#ifdef LMP_KOKKOS
#include "atom_kokkos.h"
Expand Down Expand Up @@ -34,6 +34,7 @@ static struct Images { } kImages;
static struct Tags { } kTags;
static struct TagsMap { } kTagsMap;
static struct Types { } kTypes;
static struct Virial { } kVirial;

static struct SecondDim { } kSecondDim;

Expand Down Expand Up @@ -61,6 +62,7 @@ inline void* opaque(const T* data)
return const_cast<void*>(data);
}

// if LAMMPS is built with KOKKOS, bind the PROPERTY struct to the corresponding ACCESSOR
#ifdef LMP_KOKKOS
#define DLEXT_OPAQUE_ATOM_KOKKOS(PROPERTY, ACCESSOR) \
inline void* opaque(const AtomKokkos* atom, PROPERTY) \
Expand All @@ -80,6 +82,7 @@ DLEXT_OPAQUE_ATOM_KOKKOS(Types, k_type)
#undef DLEXT_OPAQUE_ATOM_KOKKOS
#endif

// return the underlying pointers in LAMMPS (Property can be used as a tag, or actually bound to the KOKKOS accessor as above)
inline void* opaque(const Atom* atom, Positions) { return opaque(atom->x[0]); }
inline void* opaque(const Atom* atom, Velocities) { return opaque(atom->v[0]); }
inline void* opaque(const Atom* atom, Masses) { return opaque(atom->mass); }
Expand All @@ -91,22 +94,25 @@ inline void* opaque(const Atom* atom, TagsMap)
{
return opaque(const_cast<Atom*>(atom)->get_map_array());
}
inline void* opaque(const Fix* fix, Virial) { return opaque(fix->virial); }

template <typename Property>
inline void* opaque(const LAMMPSView& view, DLDeviceType device_type, Property p)
inline void* opaque(const Fix* fix, DLDeviceType device_type, Property p)
{
#ifdef LMP_KOKKOS
if (device_type == kDLCUDA)
return opaque(view.atom_kokkos_ptr(), p);
return opaque(fix->view.atom_kokkos_ptr(), p);
#endif
return opaque(view.atom_ptr(), p);
return opaque(fix->view.atom_ptr(), p);
}

inline DLDevice device_info(const LAMMPSView& view, DLDeviceType device_type)
// get the device info (id) from fix and device_type, return a DLDevice struct
inline DLDevice device_info(const Fix* fix, DLDeviceType device_type)
{
return DLDevice { device_type, view.device_id() };
return DLDevice { device_type, fix->view.device_id() };
}

// return the DLDataType code corresonding to the actual data type of the "Tag"
constexpr DLDataTypeCode dtype_code(Positions) { return kDLFloat; }
constexpr DLDataTypeCode dtype_code(Velocities) { return kDLFloat; }
constexpr DLDataTypeCode dtype_code(Masses) { return kDLFloat; }
Expand All @@ -115,7 +121,9 @@ constexpr DLDataTypeCode dtype_code(Images) { return kDLInt; }
constexpr DLDataTypeCode dtype_code(Tags) { return kDLInt; }
constexpr DLDataTypeCode dtype_code(TagsMap) { return kDLInt; }
constexpr DLDataTypeCode dtype_code(Types) { return kDLInt; }
constexpr DLDataTypeCode dtype_code(Virial) { return kDLFloat; }

// return the number of bits of the data type of a given PROPERTY
#define DLEXT_BITS_FLOAT_ARRAY(PROPERTY, TYPE) \
inline uint8_t bits(DLDeviceType device_type, PROPERTY) \
{ \
Expand All @@ -126,6 +134,7 @@ DLEXT_BITS_FLOAT_ARRAY(Positions, X_FLOAT)
DLEXT_BITS_FLOAT_ARRAY(Velocities, V_FLOAT)
DLEXT_BITS_FLOAT_ARRAY(Masses, LMP_FLOAT)
DLEXT_BITS_FLOAT_ARRAY(Forces, F_FLOAT)
DLEXT_BITS_FLOAT_ARRAY(Virial, F_FLOAT)

#undef DLEXT_BITS_FLOAT_ARRAY

Expand All @@ -151,47 +160,60 @@ inline DLDataType dtype(DLDeviceType device_type, Property p)
}

template <typename Property>
inline int64_t size(const LAMMPSView& view, Property)
inline int64_t size(const Fix* fix, Property)
{
return view.local_particle_number();
return fix->view.local_particle_number();
}
inline int64_t size(const LAMMPSView& view, Masses) { return view.atom_ptr()->ntypes + 1; }
inline int64_t size(const LAMMPSView& view, TagsMap) { return view.atom_ptr()->get_map_size(); }
inline int64_t size(const Fix* fix, Masses) { return fix->view.atom_ptr()->ntypes + 1; }
inline int64_t size(const Fix* fix, TagsMap) { return fix->view.atom_ptr()->get_map_size(); }
inline int64_t size(const Fix* fix, Virial) { return 6; }

template <typename Property>
inline int64_t size(const LAMMPSView& view, Property, SecondDim)
inline int64_t size(const Fix* fix, Property, SecondDim)
{
return 1;
}
inline int64_t size(const LAMMPSView& view, Positions, SecondDim) { return 3; }
inline int64_t size(const LAMMPSView& view, Velocities, SecondDim) { return 3; }
inline int64_t size(const LAMMPSView& view, Forces, SecondDim) { return 3; }
inline int64_t size(const Fix* fix, Positions, SecondDim) { return 3; }
inline int64_t size(const Fix* fix, Velocities, SecondDim) { return 3; }
inline int64_t size(const Fix* fix, Forces, SecondDim) { return 3; }

template <typename Property>
constexpr uint64_t offset(const LAMMPSView& view, Property p)
constexpr uint64_t offset(const Fix* fix, Property p)
{
return 0;
}

// a templated function for wrapping a C array given its data type and dimensions
// and returning a pointer to a DLPack tensor
template <typename Property>
DLManagedTensor* wrap(const LAMMPSView& view, Property property, ExecutionSpace exec_space)
DLManagedTensor* wrap(const Fix* fix, Property property, ExecutionSpace exec_space)
{
// get the device type of the fix (host or device)
auto device_type = fix->view.device_type(exec_space);

auto bridge = std::make_unique<DLDataBridge>();
bridge->tensor.manager_ctx = bridge.get();
bridge->tensor.deleter = delete_bridge;

// acquire the actual dltensor pointer
auto& dltensor = bridge->tensor.dl_tensor;
auto device_type = view.device_type(exec_space);
dltensor.data = opaque(view, device_type, property);
dltensor.device = device_info(view, device_type);

// fill in the dltensor struct
// get the underlying array/accessor of the given property and assign it to data (as void*)
dltensor.data = opaque(fix, device_type, property);
// get the device info from fix and device_type and assign it to device (as DLDevice)
dltensor.device = device_info(fix, device_type);
// get the data type of the underlying array (DLDataType) given the data type code and number of bits
dltensor.dtype = dtype(device_type, property);

// fill in the tensor shape (dimensions), strides and byte offsets
auto& shape = bridge->shape;
auto size2 = size(view, property, kSecondDim);
shape.push_back(size(view, property));
auto size2 = size(fix, property, kSecondDim);
shape.push_back(size(fix, property));
// if the array is 2D
if (size2 > 1)
shape.push_back(size2);

// strides between consecutive elements in each dim
auto& strides = bridge->strides;
strides.push_back(size2);
if (size2 > 1)
Expand All @@ -200,25 +222,28 @@ DLManagedTensor* wrap(const LAMMPSView& view, Property property, ExecutionSpace
dltensor.ndim = shape.size();
dltensor.shape = reinterpret_cast<std::int64_t*>(shape.data());
dltensor.strides = reinterpret_cast<std::int64_t*>(strides.data());
dltensor.byte_offset = offset(view, property);
dltensor.byte_offset = offset(fix, property);

return &(bridge.release()->tensor);
}

#define DLEXT_PROPERTY_FROM_VIEW(FN, SELECTOR) \
inline DLManagedTensor* FN(const LAMMPSView& view, ExecutionSpace space) \
// macro that returns a DLManagedTensor from fix for a given SELECTOR (Property)
#define DLEXT_PROPERTY_FROM_FIX(FN, SELECTOR) \
inline DLManagedTensor* FN(const Fix* fix, ExecutionSpace space) \
{ \
return wrap(view, SELECTOR, space); \
return wrap(fix, SELECTOR, space); \
}

DLEXT_PROPERTY_FROM_VIEW(positions, kPositions)
DLEXT_PROPERTY_FROM_VIEW(velocities, kVelocities)
DLEXT_PROPERTY_FROM_VIEW(masses, kMasses)
DLEXT_PROPERTY_FROM_VIEW(forces, kForces)
DLEXT_PROPERTY_FROM_VIEW(images, kImages)
DLEXT_PROPERTY_FROM_VIEW(tags, kTags)
DLEXT_PROPERTY_FROM_VIEW(tags_map, kTagsMap)
DLEXT_PROPERTY_FROM_VIEW(types, kTypes)
// finally, all the function instances to pack arrays into DLManagedTensor structs
DLEXT_PROPERTY_FROM_FIX(positions, kPositions)
DLEXT_PROPERTY_FROM_FIX(velocities, kVelocities)
DLEXT_PROPERTY_FROM_FIX(masses, kMasses)
DLEXT_PROPERTY_FROM_FIX(forces, kForces)
DLEXT_PROPERTY_FROM_FIX(images, kImages)
DLEXT_PROPERTY_FROM_FIX(tags, kTags)
DLEXT_PROPERTY_FROM_FIX(tags_map, kTagsMap)
DLEXT_PROPERTY_FROM_FIX(types, kTypes)
DLEXT_PROPERTY_FROM_FIX(virial, kVirial)

#undef DLEXT_PROPERTY

Expand Down
14 changes: 10 additions & 4 deletions dlext/include/FixDLExt.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
#ifndef DLEXT_SAMPLER_H_
#define DLEXT_SAMPLER_H_

#include "LAMMPSView.h"

#include "dlpack/dlpack.h"
#include "fix.h"

#include <functional>
#include "LAMMPSView.h"

namespace LAMMPS_NS
{
Expand All @@ -19,6 +18,7 @@ namespace dlext

using TimeStep = bigint; // bigint depends on how LAMMPS was built
using DLExtCallback = std::function<void(TimeStep)>;
using DLExtSetVirial = std::function<void(double*)>;

// } // Aliases

Expand All @@ -39,9 +39,15 @@ class DEFAULT_VISIBILITY FixDLExt : public Fix {
int setmask() override;
void post_force(int) override;
void set_callback(DLExtCallback& cb);
void set_virial_callback(DLExtSetVirial& cb);
void set_virial_global(int flag) { virial_global_flag = flag; }
void set_view(LAMMPSView _view);
LAMMPSView get_view() const;

private:
protected:
DLExtCallback callback = [](TimeStep) { };
DLExtSetVirial setVirial = [](double*) { };
LAMMPSView view;
};

void register_FixDLExt(LAMMPS* lmp);
Expand Down
41 changes: 39 additions & 2 deletions dlext/src/FixDLExt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,54 @@ FixDLExt::FixDLExt(LAMMPS* lmp, int narg, char** arg)
if (atom->map_style != Atom::MAP_ARRAY)
error->all(FLERR, "Fix dlext requires to map atoms as arrays");

// signal that this fix contributes to the global virial or not, default no
virial_global_flag = 0;

kokkosable = has_kokkos_cuda_enabled(lmp);
atomKK = dynamic_cast<AtomKokkos*>(atom);
execution_space = (on_host || !kokkosable) ? kOnHost : kOnDevice;
datamask_read = EMPTY_MASK;
datamask_modify = EMPTY_MASK;

// create an instance of LAMMPSView
// using default copy constructor and operator '=' of LAMMPSView works here
// because LAMMPSView simply encapsulates the LAMMPS pointers from lmp
view = LAMMPSView(lmp);
}

int FixDLExt::setmask()
{
return FixConst::POST_FORCE;
}

int FixDLExt::setmask() { return FixConst::POST_FORCE; }
void FixDLExt::post_force(int) { callback(update->ntimestep); }
void FixDLExt::post_force(int vflag)
{
// virial setup

v_init(vflag);

// invoke callback

callback(update->ntimestep);

// put the virial from the bias into this fix's member variable virial[6] (see fix.h)

if (virial_global_flag)
setVirial(virial);
}

// callback from the sampling method to add the biasing forces to the atoms
void FixDLExt::set_callback(DLExtCallback& cb) { callback = cb; }

// callback from the sampling method to set the virial contribution to the fix's virial
void FixDLExt::set_virial_callback(DLExtSetVirial& cb) { setVirial = cb; }

// set the LAMMPSView object
void FixDLExt::set_view(LAMMPSView _view) { view = _view; }

// get the LAMMPSView object
LAMMPSView FixDLExt::get_view() const { return view; }

void register_FixDLExt(LAMMPS* lmp)
{
auto fix_map = lmp->modify->fix_map;
Expand Down
2 changes: 1 addition & 1 deletion python/PyDLExt.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const char* const kDLTensorCapsuleName = "dltensor";
// See the DLPack Documentation https://dmlc.github.io/dlpack/latest/python_spec.html

template <PropertyGetter property>
inline PyCapsule enpycapsulate(const LAMMPSView& view, ExecutionSpace space)
inline PyCapsule enpycapsulate(const LAMMPS& view, ExecutionSpace space)
{
auto dl_managed_tensor = property(view, space);
return PyCapsule(
Expand Down
2 changes: 2 additions & 0 deletions python/lammps_dlext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ void export_FixDLExt(py::module& m)
return static_cast<FixDLExt*>(fix);
}))
.def("set_callback", &FixDLExt::set_callback)
.def("set_virial_callback", &FixDLExt::set_virial_callback)
.def_property("view", &FixDLExt::get_view, &FixDLExt::set_view);
;
}

Expand Down