diff --git a/src/nnet3/nnet-attention-component.cc b/src/nnet3/nnet-attention-component.cc index 758f5673f33..3d7fb649603 100644 --- a/src/nnet3/nnet-attention-component.cc +++ b/src/nnet3/nnet-attention-component.cc @@ -33,6 +33,7 @@ std::string RestrictedAttentionComponent::Info() const { stream << Type() << ", input-dim=" << InputDim() << ", output-dim=" << OutputDim() << ", num-heads=" << num_heads_ + << ", time-stride=" << time_stride_ << ", key-dim=" << key_dim_ << ", value-dim=" << value_dim_ << ", num-left-inputs=" << num_left_inputs_ @@ -486,7 +487,7 @@ void RestrictedAttentionComponent::GetInputIndexes( desired_indexes->resize(context_dim_); int32 n = output_index.n, x = output_index.x, i = 0; - for (int32 t = first_time, i = 0; t <= last_time; t += time_stride_, i++) { + for (int32 t = first_time; t <= last_time; t += time_stride_, i++) { (*desired_indexes)[i].n = n; (*desired_indexes)[i].t = t; (*desired_indexes)[i].x = x; diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index d462ce890d4..5f1052efc81 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -25,10 +25,12 @@ #include "nnet3/nnet-simple-component.h" #include "nnet3/nnet-general-component.h" #include "nnet3/nnet-convolutional-component.h" +#include "nnet3/nnet-attention-component.h" #include "nnet3/nnet-parse.h" #include "nnet3/nnet-computation-graph.h" + // \file This file contains some more-generic component code: things in base classes. // See nnet-component.cc for the code of the actual Components. @@ -61,6 +63,8 @@ ComponentPrecomputedIndexes* ComponentPrecomputedIndexes::NewComponentPrecompute ans = new BackpropTruncationComponentPrecomputedIndexes(); } else if (cpi_type == "TimeHeightConvolutionComponentPrecomputedIndexes") { ans = new TimeHeightConvolutionComponent::PrecomputedIndexes(); + } else if (cpi_type == "RestrictedAttentionComponentPrecomputedIndexes") { + ans = new RestrictedAttentionComponent::PrecomputedIndexes(); } if (ans != NULL) { KALDI_ASSERT(cpi_type == ans->Type()); @@ -159,6 +163,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new BatchNormComponent(); } else if (component_type == "TimeHeightConvolutionComponent") { ans = new TimeHeightConvolutionComponent(); + } else if (component_type == "RestrictedAttentionComponent") { + ans = new RestrictedAttentionComponent(); } else if (component_type == "SumBlockComponent") { ans = new SumBlockComponent(); } diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc index a138fcacceb..47a040a1789 100644 --- a/src/nnet3/nnet-test-utils.cc +++ b/src/nnet3/nnet-test-utils.cc @@ -1076,6 +1076,63 @@ void GenerateConfigSequenceCnnNew( } + +void GenerateConfigSequenceRestrictedAttention( + const NnetGenerationOptions &opts, + std::vector *configs) { + std::ostringstream ss; + + + int32 input_dim = RandInt(100, 150), + num_heads = RandInt(1, 2), + key_dim = RandInt(20, 40), + value_dim = RandInt(20, 40), + time_stride = RandInt(1, 3), + num_left_inputs = RandInt(1, 4), + num_right_inputs = RandInt(0, 2), + num_left_inputs_required = RandInt(0, num_left_inputs), + num_right_inputs_required = RandInt(0, num_right_inputs); + bool output_context = (RandInt(0, 1) == 0); + int32 context_dim = (num_left_inputs + 1 + num_right_inputs), + query_dim = key_dim + context_dim; + int32 attention_input_dim = num_heads * (key_dim + value_dim + query_dim); + + std::string cur_layer_descriptor = "input"; + + { // input layer. + ss << "input-node name=input dim=" << input_dim + << std::endl; + } + + { // affine component + ss << "component name=affine type=NaturalGradientAffineComponent input-dim=" + << input_dim << " output-dim=" << attention_input_dim << std::endl; + ss << "component-node name=affine component=affine input=input" + << std::endl; + } + + { // attention component + ss << "component-node name=attention component=attention input=affine" + << std::endl; + ss << "component name=attention type=RestrictedAttentionComponent" + << " num-heads=" << num_heads << " key-dim=" << key_dim + << " value-dim=" << value_dim << " time-stride=" << time_stride + << " num-left-inputs=" << num_left_inputs << " num-right-inputs=" + << num_right_inputs << " num-left-inputs-required=" + << num_left_inputs_required << " num-right-inputs-required=" + << num_right_inputs_required + << " output-context=" << (output_context ? "true" : "false") + << (RandInt(0, 1) == 0 ? " key-scale=1.0" : "") + << std::endl; + } + + { // output + ss << "output-node name=output input=attention" << std::endl; + } + configs->push_back(ss.str()); +} + + // generates a config sequence involving DistributeComponent. void GenerateConfigSequenceDistribute( const NnetGenerationOptions &opts, @@ -1212,11 +1269,16 @@ void GenerateConfigSequence( // We're allocating more case statements to the most recently // added type of model, to give more thorough testing where // it's needed most. - case 12: case 13: case 14: + case 12: if (!opts.allow_nonlinearity || !opts.allow_context) goto start; GenerateConfigSequenceCnnNew(opts, configs); break; + case 13: case 14: + if (!opts.allow_nonlinearity || !opts.allow_context) + goto start; + GenerateConfigSequenceRestrictedAttention(opts, configs); + break; default: KALDI_ERR << "Error generating config sequence."; }