Skip to content

Commit

Permalink
solver: check and set type to reconcile class and proto
Browse files Browse the repository at this point in the history
the solver checks its proto type (SolverParameter.type) on
instantiation:

- if the proto type is unspecified it's set according to the class type
  `Solver::type()`
- if the proto type and class type conflict, the solver dies loudly

this helps avoid accidental instantiation of a different solver type
than intended when the solver def and class differ. guaranteed type
information in the SolverParameter will simplify multi-solver
coordination too.
  • Loading branch information
shelhamer committed Nov 21, 2016
1 parent 473f143 commit e52451d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class Solver {
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
/// Harmonize solver class type with configured proto type.
void CheckType(SolverParameter* param);

SolverParameter param_;
int iter_;
Expand Down
12 changes: 12 additions & 0 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,21 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, &param);
CheckType(&param);
Init(param);
}

template <typename Dtype>
void Solver<Dtype>::CheckType(SolverParameter* param) {
// Harmonize solver class type with configured type to avoid confusion.
if (param->has_type()) {
CHECK_EQ(param->type(), this->type())
<< "Solver type must agree with instantiated solver class.";
} else {
param->set_type(this->type());
}
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
CHECK(Caffe::root_solver() || root_solver_)
Expand Down
5 changes: 5 additions & 0 deletions src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,11 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
}
}

TYPED_TEST(SGDSolverTest, TestSolverType) {
this->TestLeastSquaresUpdate();
EXPECT_NE(this->solver_->type(), string(""));
EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
}

template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
Expand Down

1 comment on commit e52451d

@wadefelix
Copy link

@wadefelix wadefelix commented on e52451d Nov 24, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I used pycaffe , If I specify the solver type in the solver.prototxt, CheckType failed.

virtual function type() shouldn't be called in constructors.
http://stackoverflow.com/questions/962132/calling-virtual-functions-inside-constructors

Please sign in to comment.