diff --git a/dlext/include/DLExt.h b/dlext/include/DLExt.h index 12bcc3e..708a200 100644 --- a/dlext/include/DLExt.h +++ b/dlext/include/DLExt.h @@ -4,8 +4,8 @@ #ifndef LAMMPS_DLPACK_EXTENSION_H_ #define LAMMPS_DLPACK_EXTENSION_H_ -#include "LAMMPSView.h" #include "atom.h" +#include "fix.h" #ifdef LMP_KOKKOS #include "atom_kokkos.h" @@ -34,6 +34,7 @@ static struct Images { } kImages; static struct Tags { } kTags; static struct TagsMap { } kTagsMap; static struct Types { } kTypes; +static struct Virial { } kVirial; static struct SecondDim { } kSecondDim; @@ -61,6 +62,7 @@ inline void* opaque(const T* data) return const_cast(data); } +// if LAMMPS is built with KOKKOS, bind the PROPERTY struct to the corresponding ACCESSOR #ifdef LMP_KOKKOS #define DLEXT_OPAQUE_ATOM_KOKKOS(PROPERTY, ACCESSOR) \ inline void* opaque(const AtomKokkos* atom, PROPERTY) \ @@ -80,6 +82,7 @@ DLEXT_OPAQUE_ATOM_KOKKOS(Types, k_type) #undef DLEXT_OPAQUE_ATOM_KOKKOS #endif +// return the underlying pointers in LAMMPS (Property can be used as a tag, or actually bound to the KOKKOS accessor as above) inline void* opaque(const Atom* atom, Positions) { return opaque(atom->x[0]); } inline void* opaque(const Atom* atom, Velocities) { return opaque(atom->v[0]); } inline void* opaque(const Atom* atom, Masses) { return opaque(atom->mass); } @@ -91,22 +94,25 @@ inline void* opaque(const Atom* atom, TagsMap) { return opaque(const_cast(atom)->get_map_array()); } +inline void* opaque(const Fix* fix, Virial) { return opaque(fix->virial); } template -inline void* opaque(const LAMMPSView& view, DLDeviceType device_type, Property p) +inline void* opaque(const Fix* fix, DLDeviceType device_type, Property p) { #ifdef LMP_KOKKOS if (device_type == kDLCUDA) - return opaque(view.atom_kokkos_ptr(), p); + return opaque(fix->view.atom_kokkos_ptr(), p); #endif - return opaque(view.atom_ptr(), p); + return opaque(fix->view.atom_ptr(), p); } -inline DLDevice device_info(const LAMMPSView& view, DLDeviceType device_type) +// get the device info (id) from fix and device_type, return a DLDevice struct +inline DLDevice device_info(const Fix* fix, DLDeviceType device_type) { - return DLDevice { device_type, view.device_id() }; + return DLDevice { device_type, fix->view.device_id() }; } +// return the DLDataType code corresonding to the actual data type of the "Tag" constexpr DLDataTypeCode dtype_code(Positions) { return kDLFloat; } constexpr DLDataTypeCode dtype_code(Velocities) { return kDLFloat; } constexpr DLDataTypeCode dtype_code(Masses) { return kDLFloat; } @@ -115,7 +121,9 @@ constexpr DLDataTypeCode dtype_code(Images) { return kDLInt; } constexpr DLDataTypeCode dtype_code(Tags) { return kDLInt; } constexpr DLDataTypeCode dtype_code(TagsMap) { return kDLInt; } constexpr DLDataTypeCode dtype_code(Types) { return kDLInt; } +constexpr DLDataTypeCode dtype_code(Virial) { return kDLFloat; } +// return the number of bits of the data type of a given PROPERTY #define DLEXT_BITS_FLOAT_ARRAY(PROPERTY, TYPE) \ inline uint8_t bits(DLDeviceType device_type, PROPERTY) \ { \ @@ -126,6 +134,7 @@ DLEXT_BITS_FLOAT_ARRAY(Positions, X_FLOAT) DLEXT_BITS_FLOAT_ARRAY(Velocities, V_FLOAT) DLEXT_BITS_FLOAT_ARRAY(Masses, LMP_FLOAT) DLEXT_BITS_FLOAT_ARRAY(Forces, F_FLOAT) +DLEXT_BITS_FLOAT_ARRAY(Virial, F_FLOAT) #undef DLEXT_BITS_FLOAT_ARRAY @@ -151,47 +160,60 @@ inline DLDataType dtype(DLDeviceType device_type, Property p) } template -inline int64_t size(const LAMMPSView& view, Property) +inline int64_t size(const Fix* fix, Property) { - return view.local_particle_number(); + return fix->view.local_particle_number(); } -inline int64_t size(const LAMMPSView& view, Masses) { return view.atom_ptr()->ntypes + 1; } -inline int64_t size(const LAMMPSView& view, TagsMap) { return view.atom_ptr()->get_map_size(); } +inline int64_t size(const Fix* fix, Masses) { return fix->view.atom_ptr()->ntypes + 1; } +inline int64_t size(const Fix* fix, TagsMap) { return fix->view.atom_ptr()->get_map_size(); } +inline int64_t size(const Fix* fix, Virial) { return 6; } template -inline int64_t size(const LAMMPSView& view, Property, SecondDim) +inline int64_t size(const Fix* fix, Property, SecondDim) { return 1; } -inline int64_t size(const LAMMPSView& view, Positions, SecondDim) { return 3; } -inline int64_t size(const LAMMPSView& view, Velocities, SecondDim) { return 3; } -inline int64_t size(const LAMMPSView& view, Forces, SecondDim) { return 3; } +inline int64_t size(const Fix* fix, Positions, SecondDim) { return 3; } +inline int64_t size(const Fix* fix, Velocities, SecondDim) { return 3; } +inline int64_t size(const Fix* fix, Forces, SecondDim) { return 3; } template -constexpr uint64_t offset(const LAMMPSView& view, Property p) +constexpr uint64_t offset(const Fix* fix, Property p) { return 0; } +// a templated function for wrapping a C array given its data type and dimensions +// and returning a pointer to a DLPack tensor template -DLManagedTensor* wrap(const LAMMPSView& view, Property property, ExecutionSpace exec_space) +DLManagedTensor* wrap(const Fix* fix, Property property, ExecutionSpace exec_space) { + // get the device type of the fix (host or device) + auto device_type = fix->view.device_type(exec_space); + auto bridge = std::make_unique(); bridge->tensor.manager_ctx = bridge.get(); bridge->tensor.deleter = delete_bridge; + // acquire the actual dltensor pointer auto& dltensor = bridge->tensor.dl_tensor; - auto device_type = view.device_type(exec_space); - dltensor.data = opaque(view, device_type, property); - dltensor.device = device_info(view, device_type); + + // fill in the dltensor struct + // get the underlying array/accessor of the given property and assign it to data (as void*) + dltensor.data = opaque(fix, device_type, property); + // get the device info from fix and device_type and assign it to device (as DLDevice) + dltensor.device = device_info(fix, device_type); + // get the data type of the underlying array (DLDataType) given the data type code and number of bits dltensor.dtype = dtype(device_type, property); + // fill in the tensor shape (dimensions), strides and byte offsets auto& shape = bridge->shape; - auto size2 = size(view, property, kSecondDim); - shape.push_back(size(view, property)); + auto size2 = size(fix, property, kSecondDim); + shape.push_back(size(fix, property)); + // if the array is 2D if (size2 > 1) shape.push_back(size2); - + // strides between consecutive elements in each dim auto& strides = bridge->strides; strides.push_back(size2); if (size2 > 1) @@ -200,25 +222,28 @@ DLManagedTensor* wrap(const LAMMPSView& view, Property property, ExecutionSpace dltensor.ndim = shape.size(); dltensor.shape = reinterpret_cast(shape.data()); dltensor.strides = reinterpret_cast(strides.data()); - dltensor.byte_offset = offset(view, property); + dltensor.byte_offset = offset(fix, property); return &(bridge.release()->tensor); } -#define DLEXT_PROPERTY_FROM_VIEW(FN, SELECTOR) \ - inline DLManagedTensor* FN(const LAMMPSView& view, ExecutionSpace space) \ +// macro that returns a DLManagedTensor from fix for a given SELECTOR (Property) +#define DLEXT_PROPERTY_FROM_FIX(FN, SELECTOR) \ + inline DLManagedTensor* FN(const Fix* fix, ExecutionSpace space) \ { \ - return wrap(view, SELECTOR, space); \ + return wrap(fix, SELECTOR, space); \ } -DLEXT_PROPERTY_FROM_VIEW(positions, kPositions) -DLEXT_PROPERTY_FROM_VIEW(velocities, kVelocities) -DLEXT_PROPERTY_FROM_VIEW(masses, kMasses) -DLEXT_PROPERTY_FROM_VIEW(forces, kForces) -DLEXT_PROPERTY_FROM_VIEW(images, kImages) -DLEXT_PROPERTY_FROM_VIEW(tags, kTags) -DLEXT_PROPERTY_FROM_VIEW(tags_map, kTagsMap) -DLEXT_PROPERTY_FROM_VIEW(types, kTypes) +// finally, all the function instances to pack arrays into DLManagedTensor structs +DLEXT_PROPERTY_FROM_FIX(positions, kPositions) +DLEXT_PROPERTY_FROM_FIX(velocities, kVelocities) +DLEXT_PROPERTY_FROM_FIX(masses, kMasses) +DLEXT_PROPERTY_FROM_FIX(forces, kForces) +DLEXT_PROPERTY_FROM_FIX(images, kImages) +DLEXT_PROPERTY_FROM_FIX(tags, kTags) +DLEXT_PROPERTY_FROM_FIX(tags_map, kTagsMap) +DLEXT_PROPERTY_FROM_FIX(types, kTypes) +DLEXT_PROPERTY_FROM_FIX(virial, kVirial) #undef DLEXT_PROPERTY diff --git a/dlext/include/FixDLExt.h b/dlext/include/FixDLExt.h index 8d8f048..7af14a6 100644 --- a/dlext/include/FixDLExt.h +++ b/dlext/include/FixDLExt.h @@ -4,11 +4,10 @@ #ifndef DLEXT_SAMPLER_H_ #define DLEXT_SAMPLER_H_ -#include "LAMMPSView.h" - +#include "dlpack/dlpack.h" #include "fix.h" - #include +#include "LAMMPSView.h" namespace LAMMPS_NS { @@ -19,6 +18,7 @@ namespace dlext using TimeStep = bigint; // bigint depends on how LAMMPS was built using DLExtCallback = std::function; +using DLExtSetVirial = std::function; // } // Aliases @@ -39,9 +39,15 @@ class DEFAULT_VISIBILITY FixDLExt : public Fix { int setmask() override; void post_force(int) override; void set_callback(DLExtCallback& cb); + void set_virial_callback(DLExtSetVirial& cb); + void set_virial_global(int flag) { virial_global_flag = flag; } + void set_view(LAMMPSView _view); + LAMMPSView get_view() const; -private: +protected: DLExtCallback callback = [](TimeStep) { }; + DLExtSetVirial setVirial = [](double*) { }; + LAMMPSView view; }; void register_FixDLExt(LAMMPS* lmp); diff --git a/dlext/src/FixDLExt.cpp b/dlext/src/FixDLExt.cpp index d57f35b..83cb157 100644 --- a/dlext/src/FixDLExt.cpp +++ b/dlext/src/FixDLExt.cpp @@ -32,17 +32,54 @@ FixDLExt::FixDLExt(LAMMPS* lmp, int narg, char** arg) if (atom->map_style != Atom::MAP_ARRAY) error->all(FLERR, "Fix dlext requires to map atoms as arrays"); + // signal that this fix contributes to the global virial or not, default no + virial_global_flag = 0; + kokkosable = has_kokkos_cuda_enabled(lmp); atomKK = dynamic_cast(atom); execution_space = (on_host || !kokkosable) ? kOnHost : kOnDevice; datamask_read = EMPTY_MASK; datamask_modify = EMPTY_MASK; + + // create an instance of LAMMPSView + // using default copy constructor and operator '=' of LAMMPSView works here + // because LAMMPSView simply encapsulates the LAMMPS pointers from lmp + view = LAMMPSView(lmp); +} + +int FixDLExt::setmask() +{ + return FixConst::POST_FORCE; } -int FixDLExt::setmask() { return FixConst::POST_FORCE; } -void FixDLExt::post_force(int) { callback(update->ntimestep); } +void FixDLExt::post_force(int vflag) +{ + // virial setup + + v_init(vflag); + + // invoke callback + + callback(update->ntimestep); + + // put the virial from the bias into this fix's member variable virial[6] (see fix.h) + + if (virial_global_flag) + setVirial(virial); +} + +// callback from the sampling method to add the biasing forces to the atoms void FixDLExt::set_callback(DLExtCallback& cb) { callback = cb; } +// callback from the sampling method to set the virial contribution to the fix's virial +void FixDLExt::set_virial_callback(DLExtSetVirial& cb) { setVirial = cb; } + +// set the LAMMPSView object +void FixDLExt::set_view(LAMMPSView _view) { view = _view; } + +// get the LAMMPSView object +LAMMPSView FixDLExt::get_view() const { return view; } + void register_FixDLExt(LAMMPS* lmp) { auto fix_map = lmp->modify->fix_map; diff --git a/python/PyDLExt.h b/python/PyDLExt.h index 9618398..f5a278d 100644 --- a/python/PyDLExt.h +++ b/python/PyDLExt.h @@ -22,7 +22,7 @@ const char* const kDLTensorCapsuleName = "dltensor"; // See the DLPack Documentation https://dmlc.github.io/dlpack/latest/python_spec.html template -inline PyCapsule enpycapsulate(const LAMMPSView& view, ExecutionSpace space) +inline PyCapsule enpycapsulate(const LAMMPS& view, ExecutionSpace space) { auto dl_managed_tensor = property(view, space); return PyCapsule( diff --git a/python/lammps_dlext.cpp b/python/lammps_dlext.cpp index 5929212..fff9b41 100644 --- a/python/lammps_dlext.cpp +++ b/python/lammps_dlext.cpp @@ -49,6 +49,8 @@ void export_FixDLExt(py::module& m) return static_cast(fix); })) .def("set_callback", &FixDLExt::set_callback) + .def("set_virial_callback", &FixDLExt::set_virial_callback) + .def_property("view", &FixDLExt::get_view, &FixDLExt::set_view); ; }