Skip to content

Commit

Permalink
dropout_schedule: Add set-dropout-proportion in nnet3 utils
Browse files Browse the repository at this point in the history
  • Loading branch information
vimalmanohar committed Dec 6, 2016
1 parent ca5bdf9 commit d055533
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
33 changes: 23 additions & 10 deletions src/nnet3/nnet-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -523,16 +523,6 @@ std::string NnetInfo(const Nnet &nnet) {
return ostr.str();
}

void SetDropoutProportion(BaseFloat dropout_proportion,
Nnet *nnet) {
for (int32 c = 0; c < nnet->NumComponents(); c++) {
Component *comp = nnet->GetComponent(c);
DropoutComponent *dc = dynamic_cast<DropoutComponent*>(comp);
if (dc != NULL)
dc->SetDropoutProportion(dropout_proportion);
}
}

void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components) {
int32 num_components = nnet.NumComponents(), num_nodes = nnet.NumNodes();
std::vector<bool> is_used(num_components, false);
Expand Down Expand Up @@ -688,6 +678,29 @@ void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet) {
if (outputs_remaining == 0)
KALDI_ERR << "All outputs were removed.";
nnet->RemoveSomeNodes(nodes_to_remove);
} else if (directive == "set-dropout-proportion") {
std::string name_pattern = "*";
// name_pattern defaults to '*' if none is given. This pattern
// matches names of components, not nodes.
config_line.GetValue("name", &name_pattern);
BaseFloat proportion = -1;
if (!config_line.GetValue("proportion", &proportion)) {
KALDI_ERR << "In edits-config, expected proportion to be set in line: "
<< config_line.WholeLine();
}
DropoutComponent *component = NULL;
int32 num_dropout_proportions_set = 0;
for (int32 c = 0; c < nnet->NumComponents(); c++) {
if (NameMatchesPattern(nnet->GetComponentName(c).c_str(),
name_pattern.c_str()) &&
(component =
dynamic_cast<DropoutComponent*>(nnet->GetComponent(c)))) {
component->SetDropoutProportion(proportion);
num_dropout_proportions_set++;
}
}
KALDI_LOG << "Set dropout proportions for "
<< num_dropout_proportions_set << " nodes.";
} else {
KALDI_ERR << "Directive '" << directive << "' is not currently "
"supported (reading edit-config).";
Expand Down
3 changes: 3 additions & 0 deletions src/nnet3/nnet-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ void FindOrphanNodes(const Nnet &nnet, std::vector<int32> *nodes);
remove internal nodes directly; instead you should use the command
'remove-orphans'.
set-dropout-proportion [name=<name-pattern>] proportion=<dropout-proportion>
Sets the dropout rates for any components of type DropoutComponent whose
names match the given <name-pattern> (e.g. lstm*). <name-pattern> defaults to "*".
\endverbatim
*/
void ReadEditConfig(std::istream &config_file, Nnet *nnet);
Expand Down
12 changes: 9 additions & 3 deletions src/nnet3bin/nnet3-copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {

bool binary_write = true;
BaseFloat learning_rate = -1,
dropout = 0.0;
dropout = -1;
std::string nnet_config, edits_config, edits_str;
BaseFloat scale = 1.0;

Expand All @@ -64,7 +64,10 @@ int main(int argc, char *argv[]) {
"will be converted to newlines before parsing. E.g. "
"'--edits=remove-orphans'.");
po.Register("set-dropout-proportion", &dropout, "Set dropout proportion "
"in all DropoutComponent to this value.");
"in all DropoutComponent to this value. "
"This option is deprecated. Use set-dropout-proportion "
"option in edits-config. See comments in ReadEditConfig() "
"in nnet3/nnet-utils.h.");
po.Register("scale", &scale, "The parameter matrices are scaled"
" by the specified value.");
po.Read(argc, argv);
Expand Down Expand Up @@ -92,7 +95,10 @@ int main(int argc, char *argv[]) {
ScaleNnet(scale, &nnet);

if (dropout > 0)
SetDropoutProportion(dropout, &nnet);
KALDI_ERR << "--dropout option is deprecated. "
<< "Use set-dropout-proportion "
<< "option in edits-config. See comments in ReadEditConfig() "
<< "in nnet3/nnet-utils.h.";

if (!edits_config.empty()) {
Input ki(edits_config);
Expand Down

0 comments on commit d055533

Please sign in to comment.