Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collision ModelParameters improvements #290

Merged
merged 4 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions Collision/examples/QCD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,30 +39,30 @@ wallgo::PhysicsModel setupQCD()
However the helper functions are needed later when defining particle content anyway (assuming non-ultrarelativistic particles), see below.
*/

/* The parameter container used by WallGo collision routines is of wallgo::ModelParameters type, which is a wrapper around std::map.
/* The parameter container used by WallGo collision routines is of wallgo::ModelParameters type, which is a wrapper around std::unordered_map.
Here we write our parameter definitions to a ModelParameters variable and pass it to modelDefinitions later. */
wallgo::ModelParameters parameters;

parameters.addOrModifyParameter("gs", 1.2279920495357861);
parameters.add("gs", 1.2279920495357861);

/* Define mass helper functions. We need the mass-squares in units of temperature, ie. m^2 / T^2.
These should take in a wallgo::ModelParameters object and return a double value

Here we use a C++11 lambda expression, with explicit return type, to define the mass function: */
These should take in a wallgo::ModelParameters object and return a double value.
Here we use lambda expressions with an explicit return types to define the mass functions: */
auto quarkThermalMassSquared = [](const wallgo::ModelParameters& params) -> double
{
const double gs = params.getParameterValue("gs");
// Read-only access to ModelParameters is through the at() function. operator[] is for write access, so cannot be used here
const double gs = params.at("gs");
return gs * gs / 6.0;
};

auto gluonThermalMassSquared = [](const wallgo::ModelParameters& params) -> double
{
const double gs = params.getParameterValue("gs");
const double gs = params.at("gs");
return 2.0 * gs * gs;
};

parameters.addOrModifyParameter("mq2", quarkThermalMassSquared(parameters));
parameters.addOrModifyParameter("mg2", gluonThermalMassSquared(parameters));
parameters.add("mq2", quarkThermalMassSquared(parameters));
parameters.add("mg2", gluonThermalMassSquared(parameters));

modelDefinition.defineParameters(parameters);

Expand Down Expand Up @@ -224,8 +224,8 @@ int main()

// Can also pack the new parameters in a wallgo::ModelParameters object and pass it to the model:
wallgo::ModelParameters changedParams;
changedParams.addOrModifyParameter("gs", 0.5);
changedParams.addOrModifyParameter("msq[1]", 0.3);
changedParams.add("gs", 0.5);
changedParams.add("mg", 0.3);
model.updateParameters(changedParams);

/* We can also request to compute integrals only for a specific off-equilibrium particle pair.
Expand Down
24 changes: 12 additions & 12 deletions Collision/examples/SM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,27 @@ wallgo::ModelParameters computeMasses(const wallgo::ModelParameters& actionParam
const double yt = actionParams.at("yt");

// SU3 gluon - Use asymptotic thermal mass (Debye mass divided by 2). See WallGo paper for details
outMsq.addOrModifyParameter("mg2", gs * gs);
outMsq.add("mg2", gs * gs);


// W boson
outMsq.addOrModifyParameter("mw2", 11. / 12. * gw * gw);
outMsq.add("mw2", 11. / 12. * gw * gw);

// U(1) boson
outMsq.addOrModifyParameter("mb2", 11. / 12. * gY * gY);
outMsq.add("mb2", 11. / 12. * gY * gY);

// Generic light quark - The mass is esimated as the asymptotic mass for a SU(3)_c fundamental fermion
outMsq.addOrModifyParameter("mq2", gs * gs / 3.);
outMsq.add("mq2", gs * gs / 3.);

// leptons - The mass is esimated as the asymptotic mass for a SU(2)_L fundamental fermion
outMsq.addOrModifyParameter("ml2", 3. / 16. * gw * gw);
outMsq.add("ml2", 3. / 16. * gw * gw);

// Higgs - The mass is estimated as the one-loop thermal mass
const double mHsqThermal = 1. / 16. * (3 * gw * gw + gY * gY + 8 * lam1H + 4 * yt * yt);

outMsq.addOrModifyParameter("mH2", mHsqThermal);
outMsq.add("mH2", mHsqThermal);
// "Goldstones"- The mass is estimated as the one-loop thermal mass
outMsq.addOrModifyParameter("mG2", mHsqThermal);
outMsq.add("mG2", mHsqThermal);

return outMsq;
}
Expand All @@ -97,11 +97,11 @@ void defineParametersSM(wallgo::ModelDefinition& inOutModelDef)
m_t=172.57
m_H=125.20
*/
params.addOrModifyParameter("gs", 1.21772); // QCD coupling at Z pole
params.addOrModifyParameter("gw", 0.651653); // SU2 coupling
params.addOrModifyParameter("gY", 0.357449); // hypercharge coupling
params.addOrModifyParameter("yt", 1.00995); // top Yukawa
params.addOrModifyParameter("lam1H", 0.129008); // Higgs self quartic
params.add("gs", 1.21772); // QCD coupling at Z pole
params.add("gw", 0.651653); // SU2 coupling
params.add("gY", 0.357449); // hypercharge coupling
params.add("yt", 1.00995); // top Yukawa
params.add("lam1H", 0.129008); // Higgs self quartic

auto massSquares = computeMasses(params);

Expand Down
33 changes: 23 additions & 10 deletions Collision/python/src/BindToPython.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,20 +199,33 @@ PYBIND11_MODULE(WG_PYTHON_MODULE_NAME, m)

py::class_<ModelParameters>(m, "ModelParameters", "Container for physics model-dependent parameters (couplings etc)")
.def(py::init<>())
.def("addOrModifyParameter", &ModelParameters::addOrModifyParameter, "Define a new named parameter or modify value of an existing one.", py::arg("name"), py::arg("value"))
.def("getParameterValue", &ModelParameters::getParameterValue, "Get current value of specified parameter. Returns 0 if the parameter is not found (prefer the contains() method if unsure)", py::arg("name"))
.def("add", &ModelParameters::add, "Define a new named parameter or modify value of an existing one.", py::arg("name"), py::arg("value"))
.def("addOrModifyParameter", &ModelParameters::add, "DEPRECATED. Use ModelParameters.add().", py::arg("name"), py::arg("value"))
.def("contains", &ModelParameters::contains, "Returns True if the specified parameter has been defined, otherwise returns False", py::arg("name"))
.def("remove", &ModelParameters::remove, "Removes the specified parameter. Does nothing if the parameter does not exist.", py::arg("name"))
.def("clear", &ModelParameters::clear, "Empties the parameter container")
.def("getNumParams", &ModelParameters::getNumParams, "Returns number of contained parameters")
.def("size", &ModelParameters::size, "Returns number of contained parameters")
.def("getParameterNames", &ModelParameters::getParameterNames, "Returns list containing names of parameters that have been defined")
// Operator[] on Python side is __getitem__. Bind a helper lambda to achieve this
.def("at",
static_cast<double& (ModelParameters::*)(const std::string&)>(&ModelParameters::at),
"Get reference to the specified parameter. Raises IndexError if the parameter is not found.",
py::arg("name"),
py::return_value_policy::reference_internal)
// Operator [] on Python side is __getitem__ (read) or __setitem__ (write), bind both accordingly
.def("__getitem__",
[](const ModelParameters& self, const std::string& paramName)
{
return self[paramName];
},
"Get current value of specified parameter. Returns 0 if the parameter is not found (prefer the contains() method if unsure)",
py::arg("name")
static_cast<double& (ModelParameters::*)(const std::string&)>(&ModelParameters::at),
"Get reference to the specified parameter. Raises IndexError if the parameter is not found.",
py::arg("name"),
py::return_value_policy::reference_internal)
.def("__setitem__",
[](TModelParameters<double>& self, const std::string& key, const double& value) {
self.add(key, value);
})
.def("getParameterValue",
static_cast<double& (ModelParameters::*)(const std::string&)>(&ModelParameters::at),
"DEPRECATED. Use the operator [] for parameter access.",
py::arg("name"),
py::return_value_policy::reference_internal
);


Expand Down
77 changes: 77 additions & 0 deletions Collision/python/tests/test_ModelParameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest

from WallGoCollision import ModelParameters

def test_addAndGet() -> None:
""""""
params = ModelParameters()

assert params.size() == 0

params.add("p1", 1.0)

assert params.size() == 1
assert params.contains("p1")
assert params["p1"] == 1.0

params.add("p2", 2.0)
assert params.size() == 2
assert params.contains("p2")
assert params["p2"] == 2.0

def test_bracketOperator() -> None:
"""Tests that the [] operator can be used in place of add()"""
params = ModelParameters()

params["p1"] = 1.0

assert params.contains("p1")
assert params["p1"] == 1.0


def test_modifyParam() -> None:
""""""
params = ModelParameters()

params.add("p1", 1.0)
# modify existing
params.add("p1", 3.0)

assert params.size() == 1
# these two are equivalent:
assert params["p1"] == 3.0
assert params.at("p1") == 3.0


def test_removeParam() -> None:
""""""
params = ModelParameters()

params.add("p1", 1.0)
params.add("p2", 2.0)

assert params.size() == 2

params.remove("p1")
assert params.size() == 1
assert not params.contains("p1")

params.remove("p2")
assert params.size() == 0
assert not params.contains("p2")

def test_invalidAccess() -> None:
"""Checks that invalid access to a parameter raises IndexError"""
params = ModelParameters()

with pytest.raises(IndexError):
params["dumb"]

with pytest.raises(IndexError):
params.at("dumb")

assert not params.contains("dumb")




2 changes: 1 addition & 1 deletion Collision/src/CollisionIntegral.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ void CollisionIntegral4::handleModelChange(const ModelChangeContext& changeConte
{
if (mModelParameters.contains(name))
{
mModelParameters.addOrModifyParameter(name, newValue);
mModelParameters.add(name, newValue);
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions Collision/src/PhysicsModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void ModelDefinition::defineParameter(const std::string& symbol, double value)
return;
}

mParameters.addOrModifyParameter(symbol, value);
mParameters.add(symbol, value);
}

void ModelDefinition::defineParameters(const ModelParameters& inParams)
Expand Down Expand Up @@ -136,10 +136,10 @@ void PhysicsModel::updateParameter(const std::string& symbol, double value)
std::cerr << "Attempted to update undefined parameter: " << symbol << "\n";
return;
}
mParameters.addOrModifyParameter(symbol, value);
mParameters.add(symbol, value);

ModelChangeContext changeContext;
changeContext.changedParams.addOrModifyParameter(symbol, value);
changeContext.changedParams.add(symbol, value);
changeContext.changedParticles = computeParticleChanges();

notifyModelChange(changeContext);
Expand All @@ -153,7 +153,7 @@ void PhysicsModel::updateParameters(const ModelParameters& newValues)
{
std::cerr << "Attempted to update undefined parameter: " << symbol << "\n";
}
mParameters.addOrModifyParameter(symbol, value);
mParameters.add(symbol, value);
}

ModelChangeContext changeContext;
Expand Down
79 changes: 58 additions & 21 deletions Collision/src/include/WallGo/ModelParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,95 @@ A function is probably better because then CollisionIntegrals can handle computa

/* Holds physics model specific parameters that enter matrix elements.
Wraps around std::unordered_map. */
template<typename T>
template<typename Value_t>
class TModelParameters
{
public:
/* Modifies value of specified parameter. If the parameter has not yet been defined, adds it with the specified value. */
void addOrModifyParameter(const std::string& paramName, T newValue);

/* Returns value of the specified parameter. If the parameter is not found, returns 0 and asserts in debug builds. */
T getParameterValue(const std::string& paramName) const;
/* Adds a new parameter with specified name and value, or modifies an existing parameter. */
void add(const std::string& paramName, const Value_t& newValue);

T at(const std::string& paramName) const { return getParameterValue(paramName); }
[[deprecated("TModelParameters::addOrModifyParameter() is deprecated. Use TModelParameters::add()")]]
void addOrModifyParameter(const std::string& paramName, const Value_t& newValue) { add(paramName, newValue); }

/* Returns value of the specified parameter. If the parameter is not found, returns 0 and asserts in debug builds. */
T operator[](const std::string& paramName) const { return getParameterValue(paramName); }
/* Returns reference to the specified parameter. Throws std::out_of_range if the parameter is not found.
This is similar to std::unordered_map::at() but we throw a more informative error message. */
Value_t& at(const std::string& paramName);

/* Const version of getParameter() */
const Value_t& at(const std::string& paramName) const;

/* Returns reference to the specified parameter. If not found, adds that parameter with a default value.
Can be used as: myParams["key"] = newValue; */
Value_t& operator[](const std::string& paramName) { return mParams[paramName]; }

[[deprecated("TModelParameters::getParameterValue() is deprecated. Use TModelParameters::at()")]]
Value_t& getParameterValue(const std::string& paramName) const { return at(paramName); }

// Removes a parameter if it exists. Note that this will invalidate existing iterators to the object
void remove(const std::string& paramName);

// True if we contain the specified parameter name
bool contains(const std::string& paramName) const { return mParams.count(paramName) > 0; }
// Empties the container
void clear() { mParams.clear(); }
// Get number of parameters in the container
size_t size() const { return mParams.size(); }
uint32_t getNumParams() const { return static_cast<uint32_t>(mParams.size()); }
// Get number of parameters in the container
size_t getNumParams() const { return mParams.size(); }
// Returns array containing names of all known parameters
std::vector<std::string> getParameterNames() const;

// Const access to the underlying map
const std::unordered_map<std::string, T>& getParameterMap() const { return mParams; }
const std::unordered_map<std::string, Value_t>& getParameterMap() const { return mParams; }

private:
std::unordered_map<std::string, T> mParams;
std::unordered_map<std::string, Value_t> mParams;
};

template<typename T>
inline void TModelParameters<T>::addOrModifyParameter(const std::string& paramName, T newValue)
template<typename Value_t>
inline void TModelParameters<Value_t>::add(const std::string& paramName, const Value_t& newValue)
{
mParams[paramName] = newValue;
}

template<typename T>
inline T TModelParameters<T>::getParameterValue(const std::string& paramName) const
template<typename Value_t>
inline Value_t& TModelParameters<Value_t>::at(const std::string& paramName)
{
if (!contains(paramName))
{
assert(false && "Parameter not found");
return static_cast<T>(0);
const std::string errorMsg = "Parameter '" + paramName + "' not found in ModelParameters.\n";
throw std::out_of_range(errorMsg);
}

return mParams[paramName];
}

template<typename Value_t>
inline const Value_t& TModelParameters<Value_t>::at(const std::string& paramName) const
{
auto it = mParams.find(paramName);
if (it == mParams.end())
{
const std::string errorMsg = "Parameter '" + paramName + "' not found in ModelParameters.\n";
throw std::out_of_range(errorMsg);
}
return mParams.at(paramName);

return it->second;
}

template<typename Value_t>
inline void TModelParameters<Value_t>::remove(const std::string& paramName)
{
mParams.erase(paramName);
}

template<typename T>
inline std::vector<std::string> TModelParameters<T>::getParameterNames() const
template<typename Value_t>
inline std::vector<std::string> TModelParameters<Value_t>::getParameterNames() const
{
std::vector<std::string> outNames;
if (getNumParams() > 0) outNames.reserve(getNumParams());

for (const auto& [key, _] : mParams)
{
outNames.push_back(key);
Expand Down
Loading