Skip to content

Commit

Permalink
python api to run solvers by cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
jdetaeye committed Aug 19, 2023
1 parent ee2641f commit 18d1859
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 34 deletions.
8 changes: 4 additions & 4 deletions include/frepple/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ class SolverCreate : public Solver {
PythonFunction getUserExitOperation() const { return userexit_operation; }

/* Python method for running the solver. */
static PyObject* solve(PyObject*, PyObject*);
static PyObject* solve(PyObject*, PyObject*, PyObject*);

/* Python method for committing the plan changes. */
static PyObject* commit(PyObject*, PyObject*);
Expand Down Expand Up @@ -680,7 +680,7 @@ class SolverCreate : public Solver {
}

private:
typedef vector<deque<Demand*> > classified_demand;
typedef vector<deque<Demand*>> classified_demand;
typedef classified_demand::iterator cluster_iterator;
classified_demand demands_per_cluster;

Expand Down Expand Up @@ -995,7 +995,7 @@ class SolverCreate : public Solver {
Duration hitMaxEarly;

bool hitMaxSize = false;

/* Simplistic flag to trace the costs being considered for alternate
* selection. */
bool logcosts = false; // SET TO TRUE AND RECOMPILE TO ACTIVATE EXTRA
Expand Down Expand Up @@ -1023,7 +1023,7 @@ class SolverCreate : public Solver {
set<const Buffer*, order_buffers> purchase_buffers;

// Structure to maintain dependency tree.
map<const Operation*, pair<unsigned short, Date> > dependency_list;
map<const Operation*, pair<unsigned short, Date>> dependency_list;

// Recursively collect all dependencies.
void populateDependencies(const Operation*);
Expand Down
9 changes: 4 additions & 5 deletions src/forecast/forecast.h
Original file line number Diff line number Diff line change
Expand Up @@ -2603,13 +2603,12 @@ class ForecastSolver : public Solver {
*/
void solve(const Demand*, void* = nullptr);

/* This is the main solver method that will appropriately call the other
* solve methods.
*/
void solve(void* v = nullptr);
/* This is the main solver method. */
void solve(void* v = nullptr) { solve(true, -1); }
void solve(bool includenetting = true, int cluster = -1);

/* Python interface for the solve method. */
static PyObject* PythonSolve(PyObject*, PyObject*);
static PyObject* solve(PyObject*, PyObject*, PyObject*);

virtual const MetaClass& getType() const { return *metadata; }
static const MetaClass* metadata;
Expand Down
50 changes: 31 additions & 19 deletions src/forecast/forecastsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ int ForecastSolver::initialize() {
x.supportgetattro();
x.supportsetattro();
x.supportcreate(create);
x.addMethod("solve", PythonSolve, METH_VARARGS, "run the solver");
x.addMethod("solve", solve, METH_VARARGS, "run the solver");
metadata->setPythonClass(x);
return x.typeReady();
}
Expand Down Expand Up @@ -180,60 +180,69 @@ void ForecastSolver::solve(const Demand* l, void* v) {
}
}

PyObject* ForecastSolver::PythonSolve(PyObject* self, PyObject* args) {
PyObject* ForecastSolver::solve(PyObject* self, PyObject* args,
PyObject* kwargs) {
static const char* kwlist[] = {"includenetting", "cluster", nullptr};
// Create the command
int* fcst_plus_netting = new int(1);
int ok = PyArg_ParseTuple(args, "|p:solve", fcst_plus_netting);
int fcst_plus_netting = 1;
int cluster = -1;
int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "|pi:solve",
const_cast<char**>(kwlist),
&fcst_plus_netting, &cluster);
if (!ok) return nullptr;

// Free Python interpreter for other threads
Py_BEGIN_ALLOW_THREADS;
try {
static_cast<Solver*>(self)->solve(fcst_plus_netting);
static_cast<ForecastSolver*>(self)->solve(fcst_plus_netting == 1, cluster);
} catch (...) {
Py_BLOCK_THREADS;
PythonType::evalException();
delete fcst_plus_netting;
return nullptr;
}

// Reclaim Python interpreter
Py_END_ALLOW_THREADS;
delete fcst_plus_netting;
return Py_BuildValue("");
}

void ForecastSolver::solve(void* v) {
void ForecastSolver::solve(bool includenetting, int cluster) {
// Switch to lazy cache flushing
auto prevCachePolicy = Cache::instance->setWriteImmediately(false);

// Reset forecastconsumed to 0 and forecastnet to forecasttotal
// Reset forecastconsumed to 0 and forecastnet to forecasttotal.
// When running for a cluster we reset the leafs and propagate.
// When running globablly we can skip the propagation.
for (auto f = Forecast::getForecasts(); f; ++f) {
if (cluster != -1 &&
(!f->isLeaf() || static_cast<Forecast*>(&*f)->getCluster() != cluster))
continue;
auto fcstdata = f->getData();
for (auto& bckt : fcstdata->getBuckets()) {
if (bckt.getValue(*Measures::forecastconsumed))
bckt.removeValue(false, Measures::forecastconsumed);
bckt.removeValue(cluster != -1, Measures::forecastconsumed);
auto fcsttotal = bckt.getValue(*Measures::forecasttotal);
if (bckt.getEnd() < Plan::instance().getCurrent() -
(ForecastSolver::getNetPastDemand()
? ForecastSolver::getNetLate()
: Duration(0L)) ||
!fcsttotal)
bckt.removeValue(false, Measures::forecastnet);
bckt.removeValue(cluster != -1, Measures::forecastnet);
else
bckt.setValue(false, Measures::forecastnet, fcsttotal);
bckt.setValue(cluster != -1, Measures::forecastnet, fcsttotal);
}
}

int fcst_plus_netting = *static_cast<int*>(v);
if (fcst_plus_netting) {
if (includenetting) {
// Time series forecasting for all leaf forecasts
// TODO Assumes that the lowest forecasting level is a leaf forecast.
if (getLogLevel() > 5)
logger << "Start forecasting for leave forecasts" << endl;
logger << "Start forecasting for leaf forecasts" << endl;
for (auto x = Forecast::getForecasts(); x; ++x) {
try {
if (x->getMethods() && x->isLeaf())
if (x->getMethods() && x->isLeaf() &&
(cluster == -1 ||
static_cast<Forecast*>(&*x)->getCluster() != cluster))
solve(static_cast<Forecast*>(&*x), nullptr);
} catch (...) {
logger << "Error: Caught an exception while forecasting '"
Expand All @@ -250,14 +259,16 @@ void ForecastSolver::solve(void* v) {
}
}
if (getLogLevel() > 5)
logger << "End forecasting for leave forecasts" << endl;
logger << "End forecasting for leaf forecasts" << endl;

// Time series forecasting for all middle-out parent forecasts
if (getLogLevel() > 5)
logger << "Start forecasting for parent forecasts" << endl;
for (auto x = Forecast::getForecasts(); x; ++x) {
try {
if (x->getMethods() && !x->isLeaf())
if (x->getMethods() && !x->isLeaf() &&
(cluster == -1 ||
static_cast<Forecast*>(&*x)->getCluster() != cluster))
solve(static_cast<Forecast*>(&*x), nullptr);
} catch (...) {
logger << "Error: Caught an exception while forecasting '"
Expand All @@ -282,7 +293,8 @@ void ForecastSolver::solve(void* v) {
sortedDemandList l;
for (auto& i : Demand::all())
if (i.getType() != *Forecast::metadata &&
i.getType() != *ForecastBucket::metadata)
i.getType() != *ForecastBucket::metadata &&
(cluster == -1 || i.getCluster() != cluster))
l.insert(&i);

// Forecast netting loop
Expand Down
21 changes: 15 additions & 6 deletions src/solver/solverplan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,27 +704,34 @@ void SolverCreate::solve(void* v) {
threads.execute();
}

PyObject* SolverCreate::solve(PyObject* self, PyObject* args) {
PyObject* SolverCreate::solve(PyObject* self, PyObject* args,
PyObject* kwargs) {
// Parse the argument
static const char* kwlist[] = {"object", "cluster", nullptr};
PyObject* dem = nullptr;
if (args && !PyArg_ParseTuple(args, "|O:solve", &dem)) return nullptr;
int cluster = -1;
int ok = PyArg_ParseTupleAndKeywords(
args, kwargs, "|Oi:solve", const_cast<char**>(kwlist), &dem, &cluster);
if (dem && !PyObject_TypeCheck(dem, Demand::metadata->pythonClass) &&
!PyObject_TypeCheck(dem, Buffer::metadata->pythonClass)) {
PyErr_SetString(PythonDataException,
"solve(d) argument must be a demand or a buffer");
"object argument must be a demand or a buffer");
return nullptr;
}

// Free Python interpreter for other threads
SolverCreate* sol = static_cast<SolverCreate*>(self);
auto prev_cluster = sol->getCluster();
Py_BEGIN_ALLOW_THREADS;
try {
SolverCreate* sol = static_cast<SolverCreate*>(self);
if (!dem) {
// Complete replan
// Complete replan or cluster replan
sol->setCluster(cluster);
sol->setAutocommit(true);
sol->solve();
} else {
// Incrementally plan a single demand
// Incrementally plan a single demand or buffer
sol->setCluster(-1);
sol->setAutocommit(false);
sol->update_user_exits();
if (PyObject_TypeCheck(dem, Demand::metadata->pythonClass))
Expand All @@ -747,10 +754,12 @@ PyObject* SolverCreate::solve(PyObject* self, PyObject* args) {
} catch (...) {
Py_BLOCK_THREADS;
PythonType::evalException();
sol->setCluster(prev_cluster);
return nullptr;
}
// Reclaim Python interpreter
Py_END_ALLOW_THREADS;
sol->setCluster(prev_cluster);
return Py_BuildValue("");
}

Expand Down

0 comments on commit 18d1859

Please sign in to comment.