Streaming inference with CNN #4097
Codecov Report
@@ Coverage Diff @@
## master #4097 +/- ##
===============================================
- Coverage 78.458% 77.966% -0.491%
- Complexity 16444 17700 +1256
===============================================
Files 1039 1084 +45
Lines 59196 64175 +4979
Branches 9692 10632 +940
===============================================
+ Hits 46444 50035 +3591
- Misses 8995 10124 +1129
- Partials 3757 4016 +259
|
2bd8226
to
e27d5e3
Compare
@lucidtronix I have a few comments on the Java side of this, and want to do a review pass. Let me know if/when it's ready for that (it may already be, now that tests are passing). |
Go for it! |
@cmnbroad I accidentally added the file AddScores.java in this PR, please ignore, I will remove it. |
First round done - it will probably take one more round after these changes are made. I'm going to submit the timeout changes for the script executor in a separate PR, so once those are in this can be rebased on that.
|
||
/** | ||
* Created by sam on 11/17/17. | ||
*/ |
This javadoc shows up in the online doc. It can be sparse for now, but should say something more descriptive, and we generally don't list the author.
Cleaned up
/** | ||
* Created by sam on 11/17/17. | ||
*/ | ||
@CommandLineProgramProperties( |
This should probably have either a @Beta or @Experimental annotation (probably @Experimental).
done
programGroup = VariantEvaluationProgramGroup.class | ||
) | ||
|
||
public class NeuralNetStreamingExecutor extends VariantWalker { |
Is 2d going to be a separate tool, or a mode of this tool ? We should probably pick a more descriptive name that doesn't have "StreamingExecutor" in it. So maybe something that will be symmetric with the names of the companion tools (i.e., training and/or 2d) once they're available.
I think 2d will be a mode of this tool. Renamed to NeuralNetInference.
@Argument(fullName = StandardArgumentDefinitions.OUTPUT_LONG_NAME, | ||
shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, | ||
doc = "Output file") | ||
private File outputFile = null; // output file produced by Python code |
We need to update this tool to write the final output file using the tool variant writer, and update this comment.
private boolean keepInfo = true; | ||
|
||
@Argument(fullName = "python-batch-size", shortName = "pbs", doc = "Size of batches for python to do inference.", optional = true) | ||
private int pythonBatchSize = 256; |
I probably originated this name on one of my branches, but we should change it to something more descriptive. I'd suggest maybe inference-batch-size.
Should this argument be @Advanced, and/or have minValue and maxValue attributes on it (the Argument annotation has attributes for that)? My experience using it was that even moderately larger values like 16k or 32k allocated huge gobs of memory.
Changed and added min max and advanced.
pythonExecutor.getAccumulatedOutput(); | ||
} | ||
final String pythonCommand = String.format( | ||
"vqsr_cnn.score_and_write_batch(model, tempFile, fifoFile, %d, %d)", curBatchSize, pythonBatchSize) + NL; |
Can you factor out the code that's duplicated here and in onTraversalSuccess, especially so the python code appears in only one place?
fixed
asyncWriter.startAsynchronousBatchWrite(batchList); | ||
waitforBatchCompletion = true; | ||
curBatchSize = 0; | ||
batchList = new ArrayList<>(pythonSyncFrequency); |
This whole code block (line 145 to here) isn't covered by the test - it never syncs until traversal is finished because there aren't many variants. I'm not sure how long the tests take to run - but could we add another (duplicate of the original) test using smaller batch/frequency values to force it through here?
added
try { | ||
spec.executeTest("testInference", this); | ||
} catch (IOException e) { | ||
e.printStackTrace(); |
Just declare "throws IOException" in the method signature, and then you can remove the try/catch block altogether.
done
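A minimal sketch of the suggestion above, written against plain JDK classes rather than the GATK test harness. The class name `ThrowsInsteadOfCatch` and the helper `countLines` are hypothetical and do not come from this PR; the point is only that declaring `throws IOException` lets a failure propagate instead of being printed and ignored.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

// Hypothetical stand-in for a test class; none of these names come from the PR.
final class ThrowsInsteadOfCatch {

    // Declares the checked exception instead of catching it and calling
    // printStackTrace(), so an I/O failure reaches the caller. In a test
    // method, that means the test fails instead of silently passing.
    static int countLines(final Path file) throws IOException {
        return Files.readAllLines(file).size();
    }

    public static void main(final String[] args) throws IOException {
        final Path tempFile = Files.createTempFile("scores", ".txt");
        try {
            Files.write(tempFile, List.of("line one", "line two"));
            System.out.println(countLines(tempFile));
        } finally {
            // Clean up the temporary file whether or not counting succeeded.
            Files.deleteIfExists(tempFile);
        }
    }
}
```

With the try/catch version, an `IOException` would only be printed and the test would still pass, which is the behavior the review comment is steering away from.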
.addArgument("architecture", architectureHD5) | ||
.addArgument(StandardArgumentDefinitions.ADD_OUTPUT_VCF_COMMANDLINE, "false"); | ||
|
||
runCommandLine(argsBuilder); |
Tests should be written using either ArgumentsBuilder/runCommandLine, or IntegrationTestSpec, but not both. IntegrationTestSpec is "old-style", but convenient. It automatically generates a temporary output file and substitutes the name for %s, and also compares the output to the expected results file. If you use runCommandLine, you need to generate your own temporary output file, and do your own expected results comparison. Since this method mixes both styles, it runs the test twice - first using runCommandLine with an output filename of literal "%s", and then again via executeTest using a generated temp file. Take a look at SelectVariantsIntegrationTest as an example of using IntegrationTestSpec, or PrintReadSparkIntegrationTest for the other style.
fixed
String varData = getVariantDataString(variant, referenceContext); | ||
String isSnp = variant.isSNP() ? "1" : "0"; | ||
String genos = "\t."; | ||
if (noSamples) genos = ""; |
These variables all could use less cryptic names; most of them are only used once though so they could be inlined below.
inlined
Responded to most of the comments, but still need to implement proper GATK-style VCF writing from Java. Also I copied @mbabadi's python package setup, but I haven't been able to upload to pypi with setup_vqsr_cnn.py, so for the time being I also have a setup.py inside the vqsr_cnn package which works with pypi, but hardcodes the version. I'm sure there is a better way. Some python tests failed but it seems to be a maven jar issue... |
Added intermediate temp file from python and proper VCF writing as we discussed, back to you @cmnbroad. |
A few more comments. I think we're close. My PR with the timeout changes is #4218, and should be reviewed soon. I'm hoping we can get all of this in for the point release, which is scheduled for Friday.
|
||
private void addScoresToVCF(){ | ||
try { | ||
Scanner scoreScan = new Scanner(new File(scoreFile)); |
The Scanner should be created inside of a try-with-resources stmt (you'll have to create the File object outside of the try block) so it will always be automatically closed, even if an exception is thrown. Also as mentioned above we should make sure it's deleted.
fixed
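The pattern requested above can be sketched with JDK classes only. The names `ScoreFileReader` and `readAndDelete` are hypothetical, and the tab-separated sample line in `main` is an assumed format, not the tool's actual score file layout.

```java
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

// Hypothetical sketch; not the code from this PR.
final class ScoreFileReader {

    // The File is created by the caller, outside the try block. The Scanner
    // is declared in the try-with-resources header, so it is closed even if
    // reading throws. The finally clause runs after the Scanner is closed
    // and removes the intermediate file.
    static List<String> readAndDelete(final File scoreFile) throws IOException {
        final List<String> lines = new ArrayList<>();
        try (final Scanner scoreScan = new Scanner(scoreFile)) {
            while (scoreScan.hasNextLine()) {
                lines.add(scoreScan.nextLine());
            }
        } finally {
            Files.deleteIfExists(scoreFile.toPath());
        }
        return lines;
    }

    public static void main(final String[] args) throws IOException {
        final File scoreFile = File.createTempFile("scores", ".tsv");
        // Assumed sample content for illustration only.
        Files.write(scoreFile.toPath(), List.of("chr1\t100\tA\t[T]\t-1.5"));
        System.out.println(readAndDelete(scoreFile));
        System.out.println("still exists: " + scoreFile.exists());
    }
}
```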
}); | ||
|
||
} catch (IOException e) { | ||
e.printStackTrace(); |
Rather than calling printStackTrace, wrap this exception in a GATKException and re-throw it.
I realize in looking at this (post traversal code) that we don't have a sanctioned way to do a second pass over the input data. We can leave this for now, but we'll probably need to add engine functionality to support this, i.e., a TwoPassVariantWalker.
Fixed, and yeah I just grabbed this from variantWalkerBase traverse.
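For reference, a sketch of the wrap-and-rethrow pattern requested above. To keep it self-contained, `ToolException` is a hypothetical stand-in for GATKException and is assumed to be unchecked; `WrapAndRethrow` and `readFirstLine` are likewise illustrative names only.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical sketch; not the code from this PR.
final class WrapAndRethrow {

    // Stand-in for GATKException, assumed here to be an unchecked exception.
    static final class ToolException extends RuntimeException {
        ToolException(final String message, final Throwable cause) {
            super(message, cause);
        }
    }

    // Instead of printStackTrace(), which lets execution continue after a
    // failure, wrap the checked exception and re-throw it. The original
    // exception is preserved as the cause.
    static String readFirstLine(final Path file) {
        try {
            return Files.readAllLines(file).get(0);
        } catch (final IOException e) {
            throw new ToolException("Error reading " + file, e);
        }
    }

    public static void main(final String[] args) {
        try {
            readFirstLine(Paths.get("this-file-does-not-exist.txt"));
        } catch (final ToolException e) {
            System.out.println(e.getMessage() + " (cause: " + e.getCause().getClass().getSimpleName() + ")");
        }
    }
}
```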
&& variant.getReference().getBaseString().equals(scoredVariant[2]) | ||
&& variant.getAlternateAlleles().toString().equals(scoredVariant[3])) { | ||
final VariantContextBuilder builder = new VariantContextBuilder(variant); | ||
builder.attribute(GATKVCFConstants.CNN_1D_KEY, scoredVariant[4]); |
Can you replace these 0,1,2,3 constants with symbolic constants saying what they represent.
fixed
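One way the symbolic constants could look, as a sketch. The field order (contig, position, reference, alternates, score) follows the indices 0 to 4 used in the snippet under review; the constant names, the class name `ScoredVariantFields`, and the tab delimiter are assumptions made for illustration.

```java
// Hypothetical sketch; constant names and delimiter are assumptions.
final class ScoredVariantFields {
    static final int CONTIG_INDEX = 0;
    static final int POS_INDEX = 1;
    static final int REF_INDEX = 2;
    static final int ALT_INDEX = 3;
    static final int SCORE_INDEX = 4;

    // Returns true when the scored record describes the same site and alleles
    // as the values passed in. Plain strings and an int are used here in
    // place of a VariantContext so the sketch runs without GATK.
    static boolean matches(final String[] scoredVariant, final String contig,
                           final int start, final String ref, final String alts) {
        return contig.equals(scoredVariant[CONTIG_INDEX])
                && Integer.toString(start).equals(scoredVariant[POS_INDEX])
                && ref.equals(scoredVariant[REF_INDEX])
                && alts.equals(scoredVariant[ALT_INDEX]);
    }

    public static void main(final String[] args) {
        // Assumed tab-separated record for illustration only.
        final String[] scoredVariant = "chr1\t100\tA\t[T]\t-1.5".split("\t");
        if (matches(scoredVariant, "chr1", 100, "A", "[T]")) {
            System.out.println("score = " + scoredVariant[SCORE_INDEX]);
        }
    }
}
```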
if(variant.getContig().equals(scoredVariant[0]) | ||
&& Integer.toString(variant.getStart()).equals(scoredVariant[1]) | ||
&& variant.getReference().getBaseString().equals(scoredVariant[2]) | ||
&& variant.getAlternateAlleles().toString().equals(scoredVariant[3])) { |
Does this handle joining variants that are not snps/indels ("OTHER") correctly? I'm just not sure what's getting written out for those.
I dug into those a bit and the ones I saw were all multiallelic sites that have one SNP allele and the other allele is an INDEL. For now they get processed like any other variant and scored as if they were SNPs. This is not ideal; maybe we should average the SNP and INDEL scores and write that. I will ask @ldgauthier what she thinks. It is very few sites, so I don't think we should worry too much about them right now.
private String getVariantInfoString(final VariantContext variant){ | ||
String varInfo = ""; | ||
for (final String attributeKey : variant.getAttributes().keySet()) { | ||
varInfo += attributeKey + "=" + variant.getAttribute(attributeKey).toString().replace(" ", "").replace("[", "").replace("]", "") + ";"; |
Can you stick a comment in here saying this is creating a python dictionary.
fixed
} | ||
|
||
|
||
private void addScoresToVCF(){ |
I'd rename this -maybe writeOutputVCFWithScores or something like that.
fixed
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
vcfWriter.close(); |
We're always creating the input writer, but only closing it if we get here, on success. It would be better to only create the writer when we need it (probably best, since that way we don't create half-baked header only output file if there is a failure).
fixed
elif b in ambiguity_codes: | ||
dna_data[i] = ambiguity_codes[b] | ||
else: | ||
print('Error! Unknown code:', b) |
This will get silently swallowed by the java front end since there is no exception/traceback here. If this is fatal, it should raise an exception. Not sure if there are other similar instances anywhere.
fixed
// | ||
// runCommandLine(argsBuilder); | ||
// | ||
// } |
Should this be removed ?
Probably yes, but it is still helpful, as every time I update the model I uncomment and use this to generate a new expected VCF.
variant.getAlternateAlleles().toString() | ||
); | ||
|
||
return varData; |
Don't really need the intermediate variable, but if you keep it, can be final.
fixed
A couple of last code cleanup requests. It might be a good time to squash the commits down, remove the timeout commits, and rebase on my timeout branch. Although if you want to wait until my branch is merged that's fine too.
final StreamingPythonScriptExecutor pythonExecutor = new StreamingPythonScriptExecutor(true); | ||
|
||
private FileOutputStream fifoWriter; | ||
private VariantContextWriter vcfWriter; |
The scope of this variable can be reduced now. See comment below.
fixed
pythonExecutor.sendSynchronousCommand(String.format("model = load_model('%s', custom_objects=vqsr_cnn.get_metric_dict())", architecture) + NL); | ||
logger.info("Loaded CNN architecture:"+architecture); | ||
} catch (IOException e) { | ||
e.printStackTrace(); |
Wrap in GATKException and re-throw.
fixed
|
||
private void writeOutputVCFWithScores(){ | ||
writeVCFHeader(); | ||
try (Scanner scoreScan = new Scanner(scoreFile)) { |
Can you remove the class level vcfWriter variable, and localize it to try-with-resources:
try (final Scanner scoreScan = new Scanner(scoreFile);
final VariantContextWriter vcfWriter = createVCFWriter(new File(outputFile))) {
scoreScan.useDelimiter("\\n");
writeVCFHeader(vcfWriter); // or call this getOutputHeader, have it return the header, and write it here
Then all the closing is handled automatically (you can remove the explicit vcfWriter.close() below), and the resource handling will be nicely symmetric.
Nice cleanup, fixed
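A self-contained sketch of the two-resource form suggested above, using a `BufferedWriter` as a hypothetical stand-in for the VariantContextWriter. The names `TwoResourceCopy` and `copyWithHeader` are illustrative. Because the writer is only opened inside the try header, after the input has been opened successfully, no header-only output is produced when the input cannot be read.

```java
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.List;
import java.util.Scanner;

// Hypothetical sketch; BufferedWriter stands in for the VCF writer.
final class TwoResourceCopy {

    // Both resources are local to the try statement and are closed
    // automatically in reverse order of declaration (writer, then scanner),
    // whether the body completes normally or throws.
    static int copyWithHeader(final File input, final File output, final String header)
            throws IOException {
        int recordsWritten = 0;
        try (final Scanner scoreScan = new Scanner(input);
             final BufferedWriter writer = Files.newBufferedWriter(output.toPath())) {
            writer.write(header);
            writer.newLine();
            while (scoreScan.hasNextLine()) {
                writer.write(scoreScan.nextLine());
                writer.newLine();
                recordsWritten++;
            }
        }
        return recordsWritten;
    }

    public static void main(final String[] args) throws IOException {
        final File input = File.createTempFile("scores", ".txt");
        final File output = File.createTempFile("annotated", ".txt");
        Files.write(input.toPath(), List.of("record one", "record two"));
        System.out.println(copyWithHeader(input, output, "#header"));
        System.out.println(Files.readAllLines(output.toPath()));
        Files.deleteIfExists(input.toPath());
        Files.deleteIfExists(output.toPath());
    }
}
```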
4d90134
to
52d0120
Compare
Thanks for the speedy reviews! Squashed and rebased and made the changes. |
@lucidtronix #4218 is merged now so you should be able to rebase this on master. It looks like when you squashed you left in some of the timeout changes, so you'll have to resolve the resulting conflicts in favor of master. |
52d0120
to
f6a8058
Compare
Ok rebased on master, if tests pass do you think it's ready to merge? |
There are still a couple of files included that should be reverted completely (that have changes left over from the previous merging of branches). We should remove those, and then we can merge once tests pass. I do still have some minor code pattern comments, but we can fix those in the 2d branch. And we probably need more test coverage before we remove @Experimental - we should have a ticket for that.
I'm still lobbying for a better tool name....Otherwise looks good!
programGroup = VariantEvaluationProgramGroup.class | ||
) | ||
|
||
public class NeuralNetInference extends VariantWalker { |
Still think we need a name that's more specific. But we can change it later.
@@ -7,6 +7,7 @@ | |||
import org.broadinstitute.hellbender.utils.runtime.ProcessOutput; | |||
import org.broadinstitute.hellbender.utils.runtime.StreamingPythonTestUtils; | |||
import org.testng.Assert; | |||
import org.testng.annotations.AfterClass; |
You should be able to revert this whole file.
@@ -17,6 +18,8 @@ | |||
import java.util.LinkedHashMap; | |||
import java.util.Map; | |||
|
|||
import static java.lang.Thread.sleep; | |||
|
Same here (revert the whole file).
f6a8058
to
8f3fc41
Compare
Yes we're still brainstorming for a better name, but it can wait for the next PR. I removed those files and rebased. |
@lucidtronix It looks like the StreamingPythonExecutorUnitTest and ProcessControllerUnitTest files are entirely removed now, instead of just being reverted (they had some stray changes included before). Those files need to be restored, then we can merge once tests pass again. |
8f3fc41
to
f5b949c
Compare
Ooops! They're back now and checks passed... |
All right then.
Annotate a VCF with scores from a pretrained model. Stream from java to python.