Skip to content

Commit

Permalink
pdlp: export from google3
Browse files Browse the repository at this point in the history
  • Loading branch information
Mizux committed Sep 30, 2024
1 parent b1c46fa commit e34c9ee
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 64 deletions.
7 changes: 4 additions & 3 deletions ortools/pdlp/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,10 @@ cc_library(
hdrs = ["sharded_quadratic_program.h"],
deps = [
":quadratic_program",
":scheduler",
":sharder",
":solvers_cc_proto",
"//ortools/base",
"//ortools/base:threadpool",
"//ortools/util:logging",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
Expand All @@ -301,9 +302,9 @@ cc_library(
srcs = ["sharder.cc"],
hdrs = ["sharder.h"],
deps = [
":scheduler",
"//ortools/base",
"//ortools/base:mathutil",
"//ortools/base:threadpool",
"//ortools/base:timer",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
Expand All @@ -317,10 +318,10 @@ cc_test(
srcs = ["sharder_test.cc"],
deps = [
":gtest_main",
":scheduler",
":sharder",
"//ortools/base",
"//ortools/base:mathutil",
"//ortools/base:threadpool",
"@com_google_absl//absl/random:distributions",
"@eigen//:eigen3",
],
Expand Down
3 changes: 2 additions & 1 deletion ortools/pdlp/primal_dual_hybrid_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,8 @@ PreprocessSolver::PreprocessSolver(QuadraticProgram qp,
: num_threads_(
NumThreads(params.num_threads(), params.num_shards(), qp, *logger)),
num_shards_(NumShards(num_threads_, params.num_shards())),
sharded_qp_(std::move(qp), num_threads_, num_shards_),
sharded_qp_(std::move(qp), num_threads_, num_shards_,
params.scheduler_type(), nullptr),
logger_(*logger) {}

SolverResult ErrorSolverResult(const TerminationReason reason,
Expand Down
20 changes: 10 additions & 10 deletions ortools/pdlp/sharded_quadratic_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ortools/pdlp/sharded_quadratic_program.h"

#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
Expand All @@ -23,9 +24,10 @@
#include "absl/log/check.h"
#include "absl/strings/string_view.h"
#include "ortools/base/logging.h"
#include "ortools/base/threadpool.h"
#include "ortools/pdlp/quadratic_program.h"
#include "ortools/pdlp/scheduler.h"
#include "ortools/pdlp/sharder.h"
#include "ortools/pdlp/solvers.pb.h"
#include "ortools/util/logging.h"

namespace operations_research::pdlp {
Expand Down Expand Up @@ -76,24 +78,22 @@ void WarnIfMatrixUnbalanced(

ShardedQuadraticProgram::ShardedQuadraticProgram(
QuadraticProgram qp, const int num_threads, const int num_shards,
operations_research::SolverLogger* logger)
SchedulerType scheduler_type, operations_research::SolverLogger* logger)
: qp_(std::move(qp)),
transposed_constraint_matrix_(qp_.constraint_matrix.transpose()),
thread_pool_(num_threads == 1
? nullptr
: std::make_unique<ThreadPool>("PDLP", num_threads)),
scheduler_(num_threads == 1 ? nullptr
: MakeScheduler(scheduler_type, num_threads)),
constraint_matrix_sharder_(qp_.constraint_matrix, num_shards,
thread_pool_.get()),
scheduler_.get()),
transposed_constraint_matrix_sharder_(transposed_constraint_matrix_,
num_shards, thread_pool_.get()),
num_shards, scheduler_.get()),
primal_sharder_(qp_.variable_lower_bounds.size(), num_shards,
thread_pool_.get()),
scheduler_.get()),
dual_sharder_(qp_.constraint_lower_bounds.size(), num_shards,
thread_pool_.get()) {
scheduler_.get()) {
CHECK_GE(num_threads, 1);
CHECK_GE(num_shards, num_threads);
if (num_threads > 1) {
thread_pool_->StartWorkers();
const int64_t work_per_iteration = qp_.constraint_matrix.nonZeros() +
qp_.variable_lower_bounds.size() +
qp_.constraint_lower_bounds.size();
Expand Down
14 changes: 8 additions & 6 deletions ortools/pdlp/sharded_quadratic_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

#include "Eigen/Core"
#include "Eigen/SparseCore"
#include "ortools/base/threadpool.h"
#include "ortools/pdlp/quadratic_program.h"
#include "ortools/pdlp/scheduler.h"
#include "ortools/pdlp/sharder.h"
#include "ortools/pdlp/solvers.pb.h"
#include "ortools/util/logging.h"

namespace operations_research::pdlp {

// This class stores:
// - A `QuadraticProgram` (QP)
// - A transposed version of the QP's constraint matrix
// - A thread pool
// - A thread scheduler
// - Various `Sharder` objects for doing sharded matrix and vector
// computations.
class ShardedQuadraticProgram {
Expand All @@ -40,8 +40,10 @@ class ShardedQuadraticProgram {
// Note that the `qp` is intentionally passed by value.
// If `logger` is not nullptr, warns about unbalanced matrices using it;
// otherwise warns via Google standard logging.
ShardedQuadraticProgram(QuadraticProgram qp, int num_threads, int num_shards,
operations_research::SolverLogger* logger = nullptr);
ShardedQuadraticProgram(
QuadraticProgram qp, int num_threads, int num_shards,
SchedulerType scheduler_type = SCHEDULER_TYPE_GOOGLE_THREADPOOL,
operations_research::SolverLogger* logger = nullptr);

// Movable but not copyable.
ShardedQuadraticProgram(const ShardedQuadraticProgram&) = delete;
Expand Down Expand Up @@ -114,7 +116,7 @@ class ShardedQuadraticProgram {
QuadraticProgram qp_;
Eigen::SparseMatrix<double, Eigen::ColMajor, int64_t>
transposed_constraint_matrix_;
std::unique_ptr<ThreadPool> thread_pool_;
std::unique_ptr<Scheduler> scheduler_;
Sharder constraint_matrix_sharder_;
Sharder transposed_constraint_matrix_sharder_;
Sharder primal_sharder_;
Expand Down
52 changes: 24 additions & 28 deletions ortools/pdlp/sharder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@
#include "absl/time/time.h"
#include "ortools/base/logging.h"
#include "ortools/base/mathutil.h"
#include "ortools/base/threadpool.h"
#include "ortools/base/timer.h"
#include "ortools/pdlp/scheduler.h"

namespace operations_research::pdlp {

using ::Eigen::VectorXd;

Sharder::Sharder(const int64_t num_elements, const int num_shards,
ThreadPool* const thread_pool,
Scheduler* const scheduler,
const std::function<int64_t(int64_t)>& element_mass)
: thread_pool_(thread_pool) {
: scheduler_(scheduler) {
CHECK_GE(num_elements, 0);
if (num_elements == 0) {
shard_starts_.push_back(0);
Expand Down Expand Up @@ -70,8 +70,8 @@ Sharder::Sharder(const int64_t num_elements, const int num_shards,
}

Sharder::Sharder(const int64_t num_elements, const int num_shards,
ThreadPool* const thread_pool)
: thread_pool_(thread_pool) {
Scheduler* const scheduler)
: scheduler_(scheduler) {
CHECK_GE(num_elements, 0);
if (num_elements == 0) {
shard_starts_.push_back(0);
Expand Down Expand Up @@ -104,34 +104,30 @@ Sharder::Sharder(const Sharder& other_sharder, const int64_t num_elements)
// The `std::max()` protects against `other_sharder.NumShards() == 0`, which
// will happen if `other_sharder` had `num_elements == 0`.
: Sharder(num_elements, std::max(1, other_sharder.NumShards()),
other_sharder.thread_pool_) {}
other_sharder.scheduler_) {}

void Sharder::ParallelForEachShard(
const std::function<void(const Shard&)>& func) const {
if (thread_pool_) {
if (scheduler_) {
absl::BlockingCounter counter(NumShards());
VLOG(2) << "Starting ParallelForEachShard()";
for (int shard_num = 0; shard_num < NumShards(); ++shard_num) {
thread_pool_->Schedule([&, shard_num]() {
WallTimer timer;
if (VLOG_IS_ON(2)) {
timer.Start();
}
func(Shard(shard_num, this));
if (VLOG_IS_ON(2)) {
timer.Stop();
VLOG(2) << "Shard " << shard_num << " with " << ShardSize(shard_num)
<< " elements and " << ShardMass(shard_num)
<< " mass finished with "
<< ShardMass(shard_num) /
std::max(int64_t{1}, absl::ToInt64Microseconds(
timer.GetDuration()))
<< " mass/usec.";
}
counter.DecrementCount();
});
}
counter.Wait();
scheduler_->ParallelFor(0, NumShards(), [&](int shard_num) {
WallTimer timer;
if (VLOG_IS_ON(2)) {
timer.Start();
}
func(Shard(shard_num, this));
if (VLOG_IS_ON(2)) {
timer.Stop();
VLOG(2) << "Shard " << shard_num << " with " << ShardSize(shard_num)
<< " elements and " << ShardMass(shard_num)
<< " mass finished with "
<< ShardMass(shard_num) /
std::max(int64_t{1},
absl::ToInt64Microseconds(timer.GetDuration()))
<< " mass/usec.";
}
});
VLOG(2) << "Done ParallelForEachShard()";
} else {
for (int shard_num = 0; shard_num < NumShards(); ++shard_num) {
Expand Down
18 changes: 9 additions & 9 deletions ortools/pdlp/sharder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "Eigen/Core"
#include "Eigen/SparseCore"
#include "absl/log/check.h"
#include "ortools/base/threadpool.h"
#include "ortools/pdlp/scheduler.h"

namespace operations_research::pdlp {

Expand Down Expand Up @@ -141,26 +141,26 @@ class Sharder {
// Creates a `Sharder` for problems with `num_elements` elements and mass of
// each element given by `element_mass`. Each shard will have roughly the same
// mass. The number of shards in the resulting `Sharder` will be approximately
// `num_shards` but may differ. The `thread_pool` will be used for parallel
// operations executed by e.g. `ParallelForEachShard()`. The `thread_pool` may
// `num_shards` but may differ. The `scheduler` will be used for parallel
// operations executed by e.g. `ParallelForEachShard()`. The `scheduler` may
// be nullptr, which means work will be executed in the same thread. If
// `thread_pool` is not nullptr, the underlying object is not owned and must
// `scheduler` is not nullptr, the underlying object is not owned and must
// outlive the `Sharder`.
Sharder(int64_t num_elements, int num_shards, ThreadPool* thread_pool,
Sharder(int64_t num_elements, int num_shards, Scheduler* scheduler,
const std::function<int64_t(int64_t)>& element_mass);

// Creates a `Sharder` for problems with `num_elements` elements and unit
// mass. This constructor exploits having all element mass equal to 1 to take
// time proportional to `num_shards` instead of `num_elements`. Also see the
// comments above the first constructor.
Sharder(int64_t num_elements, int num_shards, ThreadPool* thread_pool);
Sharder(int64_t num_elements, int num_shards, Scheduler* scheduler);

// Creates a `Sharder` for processing `matrix`. The elements correspond to
// columns of `matrix` and have mass linear in the number of non-zeros. Also
// see the comments above the first constructor.
Sharder(const Eigen::SparseMatrix<double, Eigen::ColMajor, int64_t>& matrix,
int num_shards, ThreadPool* thread_pool)
: Sharder(matrix.cols(), num_shards, thread_pool, [&matrix](int64_t col) {
int num_shards, Scheduler* scheduler)
: Sharder(matrix.cols(), num_shards, scheduler, [&matrix](int64_t col) {
return 1 + 1 * matrix.col(col).nonZeros();
}) {}

Expand Down Expand Up @@ -227,7 +227,7 @@ class Sharder {
// Size: `NumShards()`. The mass of each shard.
std::vector<int64_t> shard_masses_;
// NOT owned. May be nullptr.
ThreadPool* thread_pool_;
Scheduler* scheduler_;
};

// Like `matrix.transpose() * vector` but executed in parallel using `sharder`.
Expand Down
12 changes: 5 additions & 7 deletions ortools/pdlp/sharder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "ortools/base/gmock.h"
#include "ortools/base/logging.h"
#include "ortools/base/mathutil.h"
#include "ortools/base/threadpool.h"
#include "ortools/pdlp/scheduler.h"

namespace operations_research::pdlp {
namespace {
Expand Down Expand Up @@ -434,9 +434,8 @@ TEST_P(VariousSizesTest, LargeMatVec) {
LargeSparseMatrix(size);
const int num_threads = 5;
const int shards_per_thread = 3;
ThreadPool pool("MatrixVectorProductTest", num_threads);
pool.StartWorkers();
Sharder sharder(mat, shards_per_thread * num_threads, &pool);
GoogleThreadPoolScheduler scheduler(num_threads);
Sharder sharder(mat, shards_per_thread * num_threads, &scheduler);
VectorXd rhs = VectorXd::Random(size);
VectorXd direct = mat.transpose() * rhs;
VectorXd threaded = TransposedMatrixVectorProduct(mat, rhs, sharder);
Expand All @@ -446,9 +445,8 @@ TEST_P(VariousSizesTest, LargeMatVec) {
TEST_P(VariousSizesTest, LargeVectors) {
const int64_t size = GetParam();
const int num_threads = 5;
ThreadPool pool("SquaredNormTest", num_threads);
pool.StartWorkers();
Sharder sharder(size, num_threads, &pool);
GoogleThreadPoolScheduler scheduler(num_threads);
Sharder sharder(size, num_threads, &scheduler);
VectorXd vec = VectorXd::Random(size);
const double direct = vec.squaredNorm();
const double threaded = SquaredNorm(vec, sharder);
Expand Down
5 changes: 5 additions & 0 deletions ortools/pdlp/solvers.proto
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ message PrimalDualHybridGradientParams {
// Otherwise a default that depends on num_threads will be used.
optional int32 num_shards = 27 [default = 0];

// The type of scheduler used for CPU multi-threading. See the documentation
// of the corresponding enum for more details.
optional SchedulerType scheduler_type = 32
[default = SCHEDULER_TYPE_GOOGLE_THREADPOOL];

// If true, the iteration_stats field of the SolveLog output will be populated
// at every iteration. Note that we only compute solution statistics at
// termination checks. Setting this parameter to true may substantially
Expand Down

0 comments on commit e34c9ee

Please sign in to comment.