Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming inference with CNN #4097

Merged
merged 1 commit into from
Jan 26, 2018
Merged

Streaming inference with CNN #4097

merged 1 commit into from
Jan 26, 2018

Conversation

lucidtronix
Copy link
Contributor

Annotate a VCF with scores from a pretrained model. Stream from java to python.

@codecov-io
Copy link

codecov-io commented Jan 10, 2018

Codecov Report

Merging #4097 into master will decrease coverage by 0.492%.
The diff coverage is 84.848%.

@@               Coverage Diff               @@
##              master     #4097       +/-   ##
===============================================
- Coverage     78.458%   77.966%   -0.491%     
- Complexity     16444     17700     +1256     
===============================================
  Files           1039      1084       +45     
  Lines          59196     64175     +4979     
  Branches        9692     10632      +940     
===============================================
+ Hits           46444     50035     +3591     
- Misses          8995     10124     +1129     
- Partials        3757      4016      +259
Impacted Files Coverage Δ Complexity Δ
...nder/utils/runtime/StreamingProcessController.java 69.547% <ø> (-0.823%) 50 <0> (ø)
...ute/hellbender/utils/variant/GATKVCFConstants.java 80% <ø> (ø) 4 <0> (ø) ⬇️
...e/hellbender/utils/variant/GATKVCFHeaderLines.java 99.286% <100%> (+0.005%) 10 <0> (ø) ⬇️
...lbender/tools/walkers/vqsr/NeuralNetInference.java 84.733% <84.733%> (ø) 20 <20> (?)
.../DiscoverVariantsFromContigAlignmentsSAMSpark.java 71.839% <0%> (-28.161%) 37% <0%> (+24%)
...adinstitute/hellbender/tools/IndexFeatureFile.java 94.444% <0%> (-5.556%) 17% <0%> (+5%)
...oadinstitute/hellbender/utils/GenomeLocParser.java 84.848% <0%> (-3.03%) 57% <0%> (-2%)
...tools/walkers/mutect/SomaticLikelihoodsEngine.java 83.871% <0%> (-2.971%) 22% <0%> (+8%)
...e/hellbender/tools/spark/sv/utils/SVVCFWriter.java 86.047% <0%> (-1.709%) 10% <0%> (-1%)
...ignment/AssemblyContigWithFineTunedAlignments.java 42.105% <0%> (-1.316%) 15% <0%> (-1%)
... and 106 more

@lucidtronix lucidtronix force-pushed the sf_nn_streaming_inference branch 4 times, most recently from 2bd8226 to e27d5e3 Compare January 12, 2018 17:13
@lucidtronix lucidtronix reopened this Jan 12, 2018
@cmnbroad cmnbroad self-requested a review January 16, 2018 15:21
@cmnbroad
Copy link
Collaborator

cmnbroad commented Jan 16, 2018

@lucidtronix I have a few comments on the on the java side of this, and want to do a review pass. Let me know if/when its ready for that (it may already be, now that tests are passing).

@lucidtronix
Copy link
Contributor Author

Go for it!
I'll add the 2D CNN in a separate PR after we iron this one out...

@lucidtronix
Copy link
Contributor Author

@cmnbroad I accidentally added the file AddScores.java in this PR, please ignore, I will remove it.

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First round done - it will probably take one more round after these changes are made. I'm going to submit the timeout changes for the script executor in a separate PR, so once those are in this can be rebased on that.


/**
* Created by sam on 11/17/17.
*/
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This javadoc shows up in the online doc. It can be sparse for now, but should say something more descriptive, and we generally don't list the author.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleaned up

/**
* Created by sam on 11/17/17.
*/
@CommandLineProgramProperties(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably have either a @Beta or @Experimental annotation (probably @Experimental).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

programGroup = VariantEvaluationProgramGroup.class
)

public class NeuralNetStreamingExecutor extends VariantWalker {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is 2d going to be a separate tool, or a mode of this tool ? We should probably pick a more descriptive name that doesn't have "StreamingExecutor" in it. So maybe something that will be symmetric with the names of the companion tools (i.e., training and/or 2d) once they're available.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 2d will be a mode of this tool. Renamed to NeuralNetInference.

@Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME,
shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
doc = "Output file")
private File outputFile = null; // output file produced by Python code
Copy link
Collaborator

@cmnbroad cmnbroad Jan 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update this tool to write the final output file using the tool variant writer, and update this comment.

private boolean keepInfo = true;

@Argument(fullName = "python-batch-size", shortName = "pbs", doc = "Size of batches for python to do inference.", optional = true)
private int pythonBatchSize = 256;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I probably originated this name on one of my branches, but we should change it to something more descriptive. I'd suggest maybe inference-batch-size.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this argument be @Advanced, and/or have minValue and maxValue attributes on it (the Argument annotation has attributes for that). My experience using it was that even moderately larger values like 16k or 32k allocated huge gobs of memory).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed and added min max and advanced.

pythonExecutor.getAccumulatedOutput();
}
final String pythonCommand = String.format(
"vqsr_cnn.score_and_write_batch(model, tempFile, fifoFile, %d, %d)", curBatchSize, pythonBatchSize) + NL;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you factor out the code thats duplicated here and in onTraversalSuccess, especially so the python code appears in only one place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

asyncWriter.startAsynchronousBatchWrite(batchList);
waitforBatchCompletion = true;
curBatchSize = 0;
batchList = new ArrayList<>(pythonSyncFrequency);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole code block (line 145 to here) isn't covered by the test - it never syncs until traversal is finished because there aren't many variants. I'm not sure how long the tests takes to run - but could we add another (duplicate of the original) test using smaller batch/frequency values to force it through here ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

try {
spec.executeTest("testInference", this);
} catch (IOException e) {
e.printStackTrace();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just declare "throws IOException" in the method signature, and then you can remove the try/catch block altogether.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

.addArgument("architecture", architectureHD5)
.addArgument(StandardArgumentDefinitions.ADD_OUTPUT_VCF_COMMANDLINE, "false");

runCommandLine(argsBuilder);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests should be written using either ArgumentsBuilder/runCommandLine, or IntegrationTestSpec, but not both. IntegrationTestSpec is "old-style", but convenient. It automatically generates a temporary output file and substitutes the name for %s, and also compares the output to the expected results file. If you use runCommandLine, you need to generate your own temporary output file, and do your own expected results comparison. Since this method mixes both styles, it runs the test twice - first using runCommandLine with an output filename of literal "%s", and then again via executeTest using a generated temp file. Take a look at SelectVariantsIntegrationTest as an example of using InegrationTestSpec, or PrintReadSparkIntegrationTest for the other style.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

String varData = getVariantDataString(variant, referenceContext);
String isSnp = variant.isSNP() ? "1" : "0";
String genos = "\t.";
if (noSamples) genos = "";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These variables all could use less cryptic names; most of them are only used once though so they could be inlined below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inlined

@lucidtronix
Copy link
Contributor Author

Responded to most of the comments, but still need to implement proper gatk style vcf writing from java. Also I copied @mbabadi's python package setup, but I havent been able to upload to pypi with setup_vqsr_cnn.py so for the time being I also have a setup.py inside the vqsr_cnn package which works with pypi, but hardcodes the version. I'm sure there is a better way. Some python tests failed but it seems to be a maven jar issue...

@lucidtronix
Copy link
Contributor Author

Added intermediate temp file from python and proper VCF writing as we discussed, back to you @cmnbroad.

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more comments. I think we're close. My PR with the timeout changes is #4218, and should be reviewed soon. I'm hoping we can get all of this in for the point release, which is scheduled for Friday.


private void addScoresToVCF(){
try {
Scanner scoreScan = new Scanner(new File(scoreFile));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Scanner should be created inside of a try-with-resources stmt (you'll have to create the File object outside of the try block) so it will always be automatically closed, even if an exception is thrown. Also as mentioned above we should make sure its deleted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

});

} catch (IOException e) {
e.printStackTrace();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than calling printStackTrace, wrap this exception in a GATKException and re-throw it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize in looking at this (post traversal code) that we don't have a sanctioned way to do a second pass over the input data. We can leave this for now, but we'll probably need to add engine functionality to support this, i.e., a TwoPassVariantWalker.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, and yeah I just grabbed this from variantWalkerBase traverse.

&& variant.getReference().getBaseString().equals(scoredVariant[2])
&& variant.getAlternateAlleles().toString().equals(scoredVariant[3])) {
final VariantContextBuilder builder = new VariantContextBuilder(variant);
builder.attribute(GATKVCFConstants.CNN_1D_KEY, scoredVariant[4]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you replace these 0,1,2,3 constants with symbolic constants saying what they represent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

if(variant.getContig().equals(scoredVariant[0])
&& Integer.toString(variant.getStart()).equals(scoredVariant[1])
&& variant.getReference().getBaseString().equals(scoredVariant[2])
&& variant.getAlternateAlleles().toString().equals(scoredVariant[3])) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this handle joining variants that are not snps/indels ("OTHER") correctly ? I'm just not sure whats getting written out for those.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dug into those a bit and the ones I saw were all multiallelic sites that have one SNP allele and the other allele is a INDEL. For now they get processed like any other variant and scored as if they were SNPs. This is not ideal, maybe we should average the SNP and INDEL score and write that. I will ask @ldgauthier what she thinks. It is very few sites so I don't think we should worry too much about them right now.

private String getVariantInfoString(final VariantContext variant){
String varInfo = "";
for (final String attributeKey : variant.getAttributes().keySet()) {
varInfo += attributeKey + "=" + variant.getAttribute(attributeKey).toString().replace(" ", "").replace("[", "").replace("]", "") + ";";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you stick a comment in here saying this is creating a python dictionary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}


private void addScoresToVCF(){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rename this -maybe writeOutputVCFWithScores or something like that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

} catch (IOException e) {
e.printStackTrace();
}
vcfWriter.close();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're always creating the input writer, but only closing it if we get here, on success. It would be better to only create the writer when we need it (probably best, since that way we don't create half-baked header only output file if there is a failure).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

elif b in ambiguity_codes:
dna_data[i] = ambiguity_codes[b]
else:
print('Error! Unknown code:', b)
Copy link
Collaborator

@cmnbroad cmnbroad Jan 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will get silently swallowed by the java front end since there is no exception/traceback here. If this is fatal, it should raise an exception. Not sure if there are other similar instances anywhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

//
// runCommandLine(argsBuilder);
//
// }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be removed ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably yes, but it is still helpful as everytime I update the model I uncomment and use this to generate a new expected VCF.

variant.getAlternateAlleles().toString()
);

return varData;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't really need the intermediate variable, but if you keep it, can be final.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of last code cleanup requests. It might be a good time to squash the commits down, remove the timeout commits, and rebase on my timeout branch. Although if you want to wait until my branch is merged thats fine too.

final StreamingPythonScriptExecutor pythonExecutor = new StreamingPythonScriptExecutor(true);

private FileOutputStream fifoWriter;
private VariantContextWriter vcfWriter;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This scope of this variable can be reduced now. See comment below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

pythonExecutor.sendSynchronousCommand(String.format("model = load_model('%s', custom_objects=vqsr_cnn.get_metric_dict())", architecture) + NL);
logger.info("Loaded CNN architecture:"+architecture);
} catch (IOException e) {
e.printStackTrace();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrap in GATKException and re-throw.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


private void writeOutputVCFWithScores(){
writeVCFHeader();
try (Scanner scoreScan = new Scanner(scoreFile)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove the class level vcfWriter variable, and localize it to try-with-resources:

        try (final Scanner scoreScan = new Scanner(scoreFile);
             final VariantContextWriter vcfWriter = createVCFWriter(new File(outputFile))) {
            scoreScan.useDelimiter("\\n");
            writeVCFHeader(vcfWriter); // or call this getOutputHeader, have it return the header, and write it here

Then all the closing is handled automatically (you can remove the explicit vcfWriter.close() below), and the resource handling will be nicely symmetric.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup, fixed

@lucidtronix
Copy link
Contributor Author

Thanks for the speedy reviews! Squashed and rebased and made the changes.

@cmnbroad
Copy link
Collaborator

@lucidtronix #4218 is merged now so you should be able to rebase this on master. It looks like when you squashed you left in some of the timeout changes, so you'll have to resolve the resulting conflicts in favor of master.

@lucidtronix
Copy link
Contributor Author

Ok rebased on master, if tests pass do you think it's ready to merge?

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are still a couple of files included that should be reverted completely (that have changes left over from the previous merging of branches). We should remove those, and then we can merge once tests pass. I do still have some minor code pattern comments, but we can fix those the 2d branch. And we probably need more test coverage before we remove @Experimental - we should have a ticket for that.

I'm still lobbying for a better tool name....Otherwise looks good!

programGroup = VariantEvaluationProgramGroup.class
)

public class NeuralNetInference extends VariantWalker {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still think we need a name thats more specific. But we can change it later.

@@ -7,6 +7,7 @@
import org.broadinstitute.hellbender.utils.runtime.ProcessOutput;
import org.broadinstitute.hellbender.utils.runtime.StreamingPythonTestUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to revert this whole file.

@@ -17,6 +18,8 @@
import java.util.LinkedHashMap;
import java.util.Map;

import static java.lang.Thread.sleep;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here (revert the whole file).

@lucidtronix
Copy link
Contributor Author

Yes we're still brainstorming for a better name, but it can wait for the next PR. I removed those files and rebased.

@cmnbroad
Copy link
Collaborator

@lucidtronix It looks like the StreamingPythonExecutorUnitTest and ProcessControllerUnitTest files are entirely removed now, instead of just being reverted (they had some stray changes included before). Those files need to be restored, then we can merge once tests pass again.

@lucidtronix
Copy link
Contributor Author

Ooops! They're back now and checks passed...

Copy link
Collaborator

@cmnbroad cmnbroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All right then.

@cmnbroad cmnbroad merged commit 25f96d4 into master Jan 26, 2018
@cmnbroad cmnbroad deleted the sf_nn_streaming_inference branch January 26, 2018 20:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants