Skip to content

Commit

Permalink
[src] Make sure component is tested (failing. need lower-level tests.)
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey committed Jul 3, 2017
1 parent 739ff5c commit 54aa27e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/nnet3/nnet-attention-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ std::string RestrictedAttentionComponent::Info() const {
stream << Type() << ", input-dim=" << InputDim()
<< ", output-dim=" << OutputDim()
<< ", num-heads=" << num_heads_
<< ", time-stride=" << time_stride_
<< ", key-dim=" << key_dim_
<< ", value-dim=" << value_dim_
<< ", num-left-inputs=" << num_left_inputs_
Expand Down Expand Up @@ -486,7 +487,7 @@ void RestrictedAttentionComponent::GetInputIndexes(
desired_indexes->resize(context_dim_);
int32 n = output_index.n, x = output_index.x,
i = 0;
for (int32 t = first_time, i = 0; t <= last_time; t += time_stride_, i++) {
for (int32 t = first_time; t <= last_time; t += time_stride_, i++) {
(*desired_indexes)[i].n = n;
(*desired_indexes)[i].t = t;
(*desired_indexes)[i].x = x;
Expand Down
6 changes: 6 additions & 0 deletions src/nnet3/nnet-component-itf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
#include "nnet3/nnet-simple-component.h"
#include "nnet3/nnet-general-component.h"
#include "nnet3/nnet-convolutional-component.h"
#include "nnet3/nnet-attention-component.h"
#include "nnet3/nnet-parse.h"
#include "nnet3/nnet-computation-graph.h"



// \file This file contains some more-generic component code: things in base classes.
// See nnet-component.cc for the code of the actual Components.

Expand Down Expand Up @@ -61,6 +63,8 @@ ComponentPrecomputedIndexes* ComponentPrecomputedIndexes::NewComponentPrecompute
ans = new BackpropTruncationComponentPrecomputedIndexes();
} else if (cpi_type == "TimeHeightConvolutionComponentPrecomputedIndexes") {
ans = new TimeHeightConvolutionComponent::PrecomputedIndexes();
} else if (cpi_type == "RestrictedAttentionComponentPrecomputedIndexes") {
ans = new RestrictedAttentionComponent::PrecomputedIndexes();
}
if (ans != NULL) {
KALDI_ASSERT(cpi_type == ans->Type());
Expand Down Expand Up @@ -159,6 +163,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) {
ans = new BatchNormComponent();
} else if (component_type == "TimeHeightConvolutionComponent") {
ans = new TimeHeightConvolutionComponent();
} else if (component_type == "RestrictedAttentionComponent") {
ans = new RestrictedAttentionComponent();
} else if (component_type == "SumBlockComponent") {
ans = new SumBlockComponent();
}
Expand Down
64 changes: 63 additions & 1 deletion src/nnet3/nnet-test-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,63 @@ void GenerateConfigSequenceCnnNew(
}



// Generates a one-element config sequence describing a small network that
// exercises RestrictedAttentionComponent:
//
//   input -> NaturalGradientAffineComponent -> RestrictedAttentionComponent
//         -> output
//
// Every dimension and context setting is drawn at random, so repeated calls
// produce different configurations.
//
//  @param [in] opts      Generation options.  Not consulted by this generator;
//                        the parameter is kept so the signature matches the
//                        other GenerateConfigSequence* functions.
//  @param [out] configs  The generated config text is appended to this vector
//                        as a single string.
void GenerateConfigSequenceRestrictedAttention(
    const NnetGenerationOptions &opts,
    std::vector<std::string> *configs) {
  std::ostringstream ss;

  // The RandInt() calls are kept in their original order so that the sequence
  // of random draws (and hence the generated config for a given random state)
  // is unchanged.
  int32 input_dim = RandInt(100, 150),
      num_heads = RandInt(1, 2),
      key_dim = RandInt(20, 40),
      value_dim = RandInt(20, 40),
      time_stride = RandInt(1, 3),
      num_left_inputs = RandInt(1, 4),
      num_right_inputs = RandInt(0, 2),
      // The "required" context never exceeds the available context.
      num_left_inputs_required = RandInt(0, num_left_inputs),
      num_right_inputs_required = RandInt(0, num_right_inputs);
  bool output_context = (RandInt(0, 1) == 0);
  // One position for the current frame plus the left and right inputs.
  int32 context_dim = (num_left_inputs + 1 + num_right_inputs),
      query_dim = key_dim + context_dim;
  // Output dimension of the affine layer, which is what gets fed into the
  // attention component: per head, a key, a value and a query.
  int32 attention_input_dim = num_heads * (key_dim + value_dim + query_dim);

  // '\n' is used rather than std::endl throughout: the text produced is the
  // same, and flushing an in-memory string stream serves no purpose.

  {  // input layer.
    ss << "input-node name=input dim=" << input_dim << '\n';
  }

  {  // affine component
    ss << "component name=affine type=NaturalGradientAffineComponent input-dim="
       << input_dim << " output-dim=" << attention_input_dim << '\n';
    ss << "component-node name=affine component=affine input=input" << '\n';
  }

  {  // attention component
    // NOTE(review): the component-node line precedes the component line it
    // names; this ordering is preserved from the original -- presumably the
    // config reader does not depend on line order, confirm before changing.
    ss << "component-node name=attention component=attention input=affine"
       << '\n';
    ss << "component name=attention type=RestrictedAttentionComponent"
       << " num-heads=" << num_heads << " key-dim=" << key_dim
       << " value-dim=" << value_dim << " time-stride=" << time_stride
       << " num-left-inputs=" << num_left_inputs << " num-right-inputs="
       << num_right_inputs << " num-left-inputs-required="
       << num_left_inputs_required << " num-right-inputs-required="
       << num_right_inputs_required
       << " output-context=" << (output_context ? "true" : "false")
       // Half of the time key-scale is given explicitly; otherwise the option
       // is omitted from the config line altogether.
       << (RandInt(0, 1) == 0 ? " key-scale=1.0" : "")
       << '\n';
  }

  {  // output
    ss << "output-node name=output input=attention" << '\n';
  }
  configs->push_back(ss.str());
}


// generates a config sequence involving DistributeComponent.
void GenerateConfigSequenceDistribute(
const NnetGenerationOptions &opts,
Expand Down Expand Up @@ -1212,11 +1269,16 @@ void GenerateConfigSequence(
// We're allocating more case statements to the most recently
// added type of model, to give more thorough testing where
// it's needed most.
case 12: case 13: case 14:
case 12:
if (!opts.allow_nonlinearity || !opts.allow_context)
goto start;
GenerateConfigSequenceCnnNew(opts, configs);
break;
case 13: case 14:
if (!opts.allow_nonlinearity || !opts.allow_context)
goto start;
GenerateConfigSequenceRestrictedAttention(opts, configs);
break;
default:
KALDI_ERR << "Error generating config sequence.";
}
Expand Down

0 comments on commit 54aa27e

Please sign in to comment.