diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc index 803ca98abed..15b9b6fa63d 100644 --- a/src/nnet3/nnet-utils.cc +++ b/src/nnet3/nnet-utils.cc @@ -2,6 +2,7 @@ // Copyright 2015 Johns Hopkins University (author: Daniel Povey) // 2016 Daniel Galvez +// David Snyder // // See ../../COPYING for clarification regarding multiple authors // @@ -20,6 +21,10 @@ #include "nnet3/nnet-utils.h" #include "nnet3/nnet-simple-component.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/nnet-optimize.h" +#include "nnet3/nnet-example.h" + namespace kaldi { namespace nnet3 { @@ -421,6 +426,94 @@ std::string NnetInfo(const Nnet &nnet) { return ostr.str(); } +std::string SpMatrixOutputInfo(const Nnet &nnet) { + std::ostringstream os; + // The output 's' is vectorized SpMatrix. + std::string output_name = "s"; + int32 node_index = nnet.GetNodeIndex(output_name); + if (node_index != -1 && nnet.IsOutputNode(node_index)) { + // Check that the output dim is of the form + // (1/2)*(d+1)*d. + int32 output_dim = nnet.OutputDim(output_name); + int32 d = (0.5) * (1 + sqrt(1 + 8 * output_dim)) - 1; + if (((d + 1) * d) / 2 == output_dim) { + SpMatrix S(d); + Vector s_vec(output_dim); + GetConstantOutput(nnet, output_name, &s_vec); + S.CopyFromVec(s_vec); + Vector s(d); + Matrix P(d, d); + S.Eig(&s, &P); + SortSvd(&s, &P); + os << "Eigenvalues of output 's': " << s; + } + } + return os.str(); +} + +void GetConstantOutput(const Nnet &nnet_const, const std::string &output_name, + Vector *output) { + Nnet nnet(nnet_const); + std::string input_name = "input"; + int32 left_context, + right_context, + input_node_index = nnet.GetNodeIndex(input_name), + output_node_index = nnet.GetNodeIndex(output_name); + if (output_node_index == -1 && !nnet.IsOutputNode(output_node_index)) + KALDI_ERR << "No output node called '" << output_name + << "' in the network."; + if (input_node_index == -1 && nnet.IsInputNode(input_node_index)) + KALDI_ERR << "No input node called '" << input_name + << "' in the network."; + KALDI_ASSERT(output->Dim() == nnet.OutputDim(output_name)); + ComputeSimpleNnetContext(nnet, &left_context, &right_context); + + // It's difficult to get the output of the node + // directly. Instead, we can create some fake input, + // propagate it through the network, and read out the + // output. + CuMatrix cu_feats(left_context + right_context, + nnet.InputDim(input_name)); + Matrix feats(cu_feats); + + ComputationRequest request; + NnetIo nnet_io = NnetIo(input_name, 0, feats); + request.inputs.clear(); + request.outputs.clear(); + request.inputs.resize(1); + request.outputs.resize(1); + request.need_model_derivative = false; + request.store_component_stats = false; + + std::vector output_indexes; + request.inputs[0].name = input_name; + request.inputs[0].indexes = nnet_io.indexes; + request.inputs[0].has_deriv = false; + output_indexes.resize(1); + output_indexes[0].n = 0; + output_indexes[0].t = 0; + request.outputs[0].name = output_name; + request.outputs[0].indexes = output_indexes; + request.outputs[0].has_deriv = false; + + CachingOptimizingCompiler compiler(nnet, NnetOptimizeOptions()); + const NnetComputation *computation = compiler.Compile(request); + NnetComputer computer(NnetComputeOptions(), *computation, + nnet, &nnet); + + // check to see if something went wrong. + if (request.inputs.empty()) + KALDI_ERR << "No input in computation request."; + if (request.outputs.empty()) + KALDI_ERR << "No output in computation request."; + + computer.AcceptInput("input", &cu_feats); + computer.Forward(); + const CuMatrixBase &output_mat = computer.GetOutput(output_name); + CuSubVector output_vec(output_mat, 0); + output->CopyFromVec(output_vec); +} + } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-utils.h b/src/nnet3/nnet-utils.h index 149c0e08485..f2f5e19d34a 100644 --- a/src/nnet3/nnet-utils.h +++ b/src/nnet3/nnet-utils.h @@ -144,12 +144,22 @@ int32 NumUpdatableComponents(const Nnet &dest); void ConvertRepeatedToBlockAffine(Nnet *nnet); /// This function returns various info about the neural net. -/// If the nnet satisfied IsSimpleNnet(nnet), the info includes "left-context=5\nright-context=3\n...". The info includes +/// If the nnet satisfied IsSimpleNnet(nnet), the info includes +/// "left-context=5\nright-context=3\n...". The info includes /// the output of nnet.Info(). /// This is modeled after the info that AmNnetSimple returns in its /// Info() function (we need this in the CTC code). std::string NnetInfo(const Nnet &nnet); +/// Returns a string containing info on an output node called 's' if +/// it can be interpreted as a symmetric matrix. +std::string SpMatrixOutputInfo(const Nnet &nnet); + +/// This function assumes that the node named in 'output_node' is a constant +/// function of the input features (e.g, a ConstantFunctionComponent is +/// its input) and returns it in 'out'. +void GetConstantOutput(const Nnet &nnet, const std::string &output_name, + Vector *out); } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3bin/nnet3-info.cc b/src/nnet3bin/nnet3-info.cc index 6b7fb2c629e..6a9c5c402f3 100644 --- a/src/nnet3bin/nnet3-info.cc +++ b/src/nnet3bin/nnet3-info.cc @@ -20,6 +20,7 @@ #include "base/kaldi-common.h" #include "util/common-utils.h" #include "nnet3/nnet-nnet.h" +#include "nnet3/nnet-utils.h" int main(int argc, char *argv[]) { try { @@ -35,22 +36,23 @@ int main(int argc, char *argv[]) { "e.g.:\n" " nnet3-info 0.raw\n" "See also: nnet3-am-info\n"; - + ParseOptions po(usage); - + po.Read(argc, argv); - + if (po.NumArgs() != 1) { po.PrintUsage(); exit(1); } std::string raw_nnet_rxfilename = po.GetArg(1); - + Nnet nnet; ReadKaldiObject(raw_nnet_rxfilename, &nnet); std::cout << nnet.Info(); + std::cout << SpMatrixOutputInfo(nnet); return 0; } catch(const std::exception &e) { @@ -72,6 +74,6 @@ component-node name=affine1_node component=affine1 input=Append(Offset(input, -4 component-node name=nonlin1 component=relu1 input=affine1_node component-node name=final_affine component=final_affine input=nonlin1 component-node name=output_nonlin component=logsoftmax input=final_affine -output-node name=output input=output_nonlin +output-node name=output input=output_nonlin EOF */ diff --git a/src/xvectorbin/nnet3-xvector-compute.cc b/src/xvectorbin/nnet3-xvector-compute.cc index 5d023ab4a3a..bad6c1aca63 100644 --- a/src/xvectorbin/nnet3-xvector-compute.cc +++ b/src/xvectorbin/nnet3-xvector-compute.cc @@ -179,16 +179,18 @@ int main(int argc, char *argv[]) { for (int32 i = out_offset; i < std::min(out_offset + xvector_period, num_rows); i++) xvector_mat.Row(i).CopyFromVec(xvector); - } else + } else { xvector_mat.Row(chunk_indx).CopyFromVec(xvector); + } } // If output is a vector, scale it by the total weight. if (output_as_vector) { xvector_avg.Scale(1.0 / total_chunk_weight); vector_writer.Write(utt, xvector_avg); - } else + } else { matrix_writer.Write(utt, xvector_mat); + } frame_count += feats.NumRows(); num_success++;