-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jamr rebase #8
base: master
Are you sure you want to change the base?
Jamr rebase #8
Changes from all commits
9278ab3
3dc482c
48ba8fd
8c61936
290e65f
c3c4f95
d1bdfae
d6712ab
2db8dd2
f86dadd
ab6cda7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,10 @@ package edu.cmu.lti.nlp.amr | |
|
||
import java.io.StringWriter | ||
import java.io.PrintWriter | ||
import java.io.PrintStream | ||
|
||
import edu.cmu.lti.nlp.amr.BasicFeatureVector.DecoderResult | ||
|
||
import scala.io.Source.fromFile | ||
import scala.collection.mutable.Map | ||
import scala.collection.mutable.Set | ||
|
@@ -29,6 +33,7 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
def isSwitch(s : String) = (s(0) == '-') | ||
list match { | ||
case Nil => map | ||
case "--model-name" :: value :: tail => parseOptions(map + ('modelName -> value), tail) | ||
case "--stage1-only" :: l => parseOptions(map + ('stage1Only -> "true"), l) | ||
case "--stage1-oracle" :: l => parseOptions(map + ('stage1Oracle -> "true"), l) | ||
case "--stage1-train" :: l => parseOptions(map + ('stage1Train -> "true"), l) | ||
|
@@ -64,6 +69,8 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
case "--ner" :: value :: tail => parseOptions(map + ('ner -> value), tail) | ||
case "--snt" :: value :: tail => parseOptions(map ++ Map('notTokenized -> value), tail) | ||
case "--tok" :: value :: tail => parseOptions(map ++ Map('tokenized -> value), tail) | ||
case "--input" :: value :: tail => parseOptions(map ++ Map('input -> value), tail) | ||
case "--output" :: value :: tail => parseOptions(map ++ Map('output -> value), tail) | ||
case "-v" :: value :: tail => parseOptions(map ++ Map('verbosity -> value), tail) | ||
|
||
//case string :: opt2 :: tail if isSwitch(opt2) => parseOptions(map ++ Map('infile -> string), list.tail) | ||
|
@@ -77,37 +84,52 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
val now = System.nanoTime | ||
val result = a | ||
val micros = (System.nanoTime - now) / 1000 | ||
System.err.println("Decoded in %,d microseconds".format(micros)) | ||
logger(0,"Decoded in %,d microseconds".format(micros)) | ||
result | ||
} | ||
|
||
// cache the previous Stage2 GraphDecoder object so that it can be re-used. This saves us from having to reload | ||
// the weights if JAMR is re-invoked in the same process. Save multiple version using modelName is a key | ||
var previousStage2: Map[String, GraphDecoder.Decoder] = Map() | ||
|
||
var previousStage1: Map[String, ConceptInvoke.Decoder] = Map() | ||
|
||
// an optional callback that lets us get direct access to decoderResultGraph when parsing | ||
var resultHandler: Option[Graph => Unit] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see this ever being set to anything other than Also, adding side-effecting callbacks makes the code harder to grok/change/debug/test. Is there a way to extract a method for the part of the code that produces |
||
|
||
def main(args: Array[String]) { | ||
|
||
if (args.length == 0) { println(usage); sys.exit(1) } | ||
val options = parseOptions(Map(),args.toList) | ||
|
||
verbosity = options.getOrElse('verbosity, "0").toInt | ||
|
||
val modelName = options.getOrElse('modelName, "Default").toString | ||
|
||
val outputFormat = options.getOrElse('outputFormat,"triples").split(",").toList | ||
// Output format is comma separated list of: nodes,edges,AMR,triples | ||
|
||
val stage1 : ConceptInvoke.Decoder = { | ||
val stage1 : Option[ConceptInvoke.Decoder] = if (previousStage1.contains(modelName)) previousStage1.get(modelName) else { | ||
if (!options.contains('stage1Oracle) && !options.contains('stage2Train)) { | ||
ConceptInvoke.Decoder(options, oracle = false) | ||
Some(ConceptInvoke.Decoder(options, oracle = false)) | ||
} else { | ||
assert(!options.contains('stage1Train), "Error: --stage1-oracle should not be specified with --stage1-train") | ||
ConceptInvoke.Decoder(options, oracle = true) | ||
Some(ConceptInvoke.Decoder(options, oracle = true)) | ||
} | ||
} | ||
|
||
val stage2 : Option[GraphDecoder.Decoder] = { | ||
|
||
val stage2 : Option[GraphDecoder.Decoder] = if (previousStage2.contains(modelName)) previousStage2.get(modelName) | ||
else { | ||
if((options.contains('stage1Only) || options.contains('stage1Train)) && !options.contains('stage2Train)) { | ||
None | ||
} else { | ||
Some(GraphDecoder.Decoder(options)) | ||
} | ||
None | ||
} else { | ||
Some(GraphDecoder.Decoder(options)) | ||
} | ||
} | ||
|
||
if (stage2.isDefined) previousStage2 += (modelName -> stage2.get) | ||
|
||
val stage2Oracle : Option[GraphDecoder.Decoder] = { | ||
if(options.contains('trainingData) || options.contains('stage2Train)) { | ||
Some(GraphDecoder.Oracle(options)) | ||
|
@@ -121,7 +143,7 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
////////////////// Training //////////////// | ||
|
||
if (options.contains('stage1Train) && options.contains('stage2Train)) { | ||
System.err.println("Error: please specify either stage1 training or stage2 training (not both)") | ||
logger(0,"Error: please specify either stage1 training or stage2 training (not both)") | ||
sys.exit(1) | ||
} | ||
|
||
|
@@ -144,28 +166,33 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
/////////////////// Decoding ///////////////// | ||
|
||
if (!options.contains('stage1Weights)) { | ||
System.err.println("Error: No stage1 weights file specified"); sys.exit(1) | ||
logger(0,"Error: No stage1 weights file specified"); sys.exit(1) | ||
} | ||
stage1.features.weights.read(Source.fromFile(options('stage1Weights).asInstanceOf[String]).getLines()) | ||
stage1.get.features.weights.read(Source.fromFile(options('stage1Weights).asInstanceOf[String]).getLines()) | ||
|
||
//logger(0, "Stage1 weights:\n"+stage1.features.weights.toString) | ||
|
||
if (!options.contains('stage2Weights)) { | ||
System.err.println("Error: No stage2 weights file specified") | ||
logger(0,"Error: No stage2 weights file specified") | ||
sys.exit(1) | ||
} | ||
val stage2weightfile : String = options('stage2Weights) | ||
|
||
logger(0, "Reading weights") | ||
|
||
if (stage2 != None) { | ||
stage2.get.features.weights.read(Source.fromFile(stage2weightfile).getLines()) | ||
if (stage2Oracle != None) { | ||
stage2Oracle.get.features.weights.read(Source.fromFile(stage2weightfile).getLines()) | ||
// don't read the weights again if they are already there, this check is probably redundant | ||
if (stage2.get.features.weights.fmap.size == 0) { | ||
logger(0, f"Reading weights for $modelName") | ||
val stage2weightlines = Source.fromFile( stage2weightfile ).getLines( ) | ||
stage2.get.features.weights.read( stage2weightlines ) | ||
if( stage2Oracle != None ) { | ||
stage2Oracle.get.features.weights.read( Source.fromFile( stage2weightfile ).getLines( ) ) | ||
} | ||
} | ||
} | ||
logger(0, "done") | ||
|
||
val input = stdin.getLines.toArray | ||
val input = if (options.contains('input)) fromFile(options('input)).getLines().toArray else stdin.getLines.toArray | ||
val tokenized = fromFile(options('tokenized).asInstanceOf[String]).getLines/*.map(x => x)*/.toArray | ||
val nerFile = Corpus.splitOnNewline(fromFile(options('ner).asInstanceOf[String]).getLines).toArray | ||
val oracleData : Array[String] = if (options.contains('trainingData)) { | ||
|
@@ -189,6 +216,8 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
} | ||
val spanF1 = F1(0,0,0) | ||
|
||
val outStream = if (options.contains('output)) new PrintStream(options('output)) else System.out | ||
|
||
for ((block, i) <- input.zipWithIndex) { | ||
try { | ||
time { | ||
|
@@ -197,7 +226,7 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
val tok = tokenized(i) | ||
val ner = nerFile(i) | ||
val inputGraph = if (options.contains('stage1Oracle)) { Some(AMRTrainingData(oracleData(i)).toInputGraph) } else { None } | ||
val stage1Result = stage1.decode(new Input(inputGraph, | ||
val stage1Result = stage1.get.decode(new Input(inputGraph, | ||
tok.split(" "), | ||
line.split(" "), | ||
dependencies(i), | ||
|
@@ -277,46 +306,50 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l | |
})+"\n") | ||
} | ||
|
||
println("# ::snt "+line) | ||
println("# ::tok "+tok) | ||
outStream.println("# ::snt "+line) | ||
outStream.println("# ::tok "+tok) | ||
val sdf = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS") | ||
decoderResultGraph.assignOpN() | ||
decoderResultGraph.sortRelations() | ||
decoderResultGraph.makeIds() | ||
println("# ::alignments "+decoderResultGraph.spans.map(_.format).mkString(" ")+" ::annotator "+VERSION+" ::date "+sdf.format(new Date)) | ||
|
||
// fire the callback if defined | ||
if (resultHandler.isDefined) resultHandler.get.apply(decoderResultGraph) | ||
|
||
outStream.println("# ::alignments "+decoderResultGraph.spans.map(_.format).mkString(" ")+" ::annotator "+VERSION+" ::date "+sdf.format(new Date)) | ||
if (outputFormat.contains("nodes")) { | ||
println(decoderResultGraph.printNodes.map(x => "# ::node\t" + x).mkString("\n")) | ||
outStream.println(decoderResultGraph.printNodes.map(x => "# ::node\t" + x).mkString("\n")) | ||
} | ||
if (outputFormat.contains("root")) { | ||
println(decoderResultGraph.printRoot) | ||
} | ||
if (outputFormat.contains("edges") && decoderResultGraph.root.relations.size > 0) { | ||
println(decoderResultGraph.printEdges.map(x => "# ::edge\t" + x).mkString("\n")) | ||
outStream.println(decoderResultGraph.printEdges.map(x => "# ::edge\t" + x).mkString("\n")) | ||
} | ||
if (outputFormat.contains("AMR")) { | ||
println(decoderResultGraph.prettyString(detail=1, pretty=true)) | ||
outStream.println(decoderResultGraph.prettyString(detail=1, pretty=true)) | ||
} | ||
if (outputFormat.contains("triples")) { | ||
println(decoderResultGraph.printTriples(detail = 1)) | ||
outStream.println(decoderResultGraph.printTriples(detail = 1)) | ||
} | ||
println() | ||
outStream.println() | ||
} // time | ||
} catch { // try | ||
case e : Throwable => if (options.contains('ignoreParserErrors)) { | ||
println("# ::snt "+input(i)) | ||
println("# ::tok "+tokenized(i)) | ||
outStream.println("# ::snt "+input(i)) | ||
outStream.println("# ::tok "+tokenized(i)) | ||
val sdf = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS") | ||
println("# ::alignments 0-1|0 ::annotator "+VERSION+" ::date "+sdf.format(new Date)) | ||
println("# THERE WAS AN EXCEPTION IN THE PARSER. Returning an empty graph.") | ||
outStream.println("# ::alignments 0-1|0 ::annotator "+VERSION+" ::date "+sdf.format(new Date)) | ||
outStream.println("# THERE WAS AN EXCEPTION IN THE PARSER. Returning an empty graph.") | ||
if (options.contains('printStackTraceOnErrors)) { | ||
val sw = new StringWriter() | ||
e.printStackTrace(new PrintWriter(sw)) | ||
println(sw.toString.split("\n").map(x => "# "+x).mkString("\n")) | ||
outStream.println(sw.toString.split("\n").map(x => "# "+x).mkString("\n")) | ||
} | ||
logger(-1, " ********** THERE WAS AN EXCEPTION IN THE PARSER. *********") | ||
if (verbosity >= -1) { e.printStackTrace } | ||
logger(-1, "Continuing. To exit on errors, please run without --ignore-parser-errors") | ||
println(Graph.empty.prettyString(detail=1, pretty=true) + '\n') | ||
outStream.println(Graph.empty.prettyString(detail=1, pretty=true) + '\n') | ||
} else { | ||
throw e | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
package edu.cmu.lti.nlp.amr | ||
|
||
import java.lang.Throwable | ||
import java.util.regex.Pattern | ||
import scala.util.{Failure, Success, Try} | ||
import scala.util.matching.Regex | ||
import scala.collection.mutable.Map | ||
import scala.collection.mutable.Set | ||
|
@@ -12,13 +14,20 @@ object AlignSpans2 { | |
val weird_system = if ("abc".split("").toList == List("a","b","c")) { | ||
true // on some systems "abc".split("") gives Array("a","b","c") and I have no idea why | ||
} else { | ||
false // usually "abc".split("") is Array("","a","b","c") | ||
false // usually "abc".split("") is Array("","a","b","c") | ||
} | ||
|
||
// capture spans in a string with a marker so that you can see what they look like before and after update operations | ||
def dumpSpans(marker: String, spans: ArrayBuffer[Span]): String = { | ||
var sb = new StringBuilder() | ||
for (s <- spans) sb.append(f"$marker: $s\n") | ||
sb.toString() | ||
} | ||
|
||
def align(sentence: Array[String], graph: Graph) { | ||
val stemmedSentence = sentence.map(stemmer(_)) | ||
val wordToSpan : Array[Option[Int]] = sentence.map(x => None) | ||
logger(3, "Stemmed sentence "+stemmedSentence.toList.toString) | ||
logger(1, "Stemmed sentence "+stemmedSentence.toList.toString) | ||
|
||
val namedEntity = new SpanAligner(sentence, graph) { | ||
concept = "name" | ||
|
@@ -187,14 +196,19 @@ object AlignSpans2 { | |
addAllSpans(singleConcept, graph, wordToSpan, addCoRefs=false) | ||
addAllSpans(fuzzyConcept, graph, wordToSpan, addCoRefs=false) | ||
addAllSpans(US, graph, wordToSpan, addCoRefs=false) | ||
try { updateSpans(namedEntityCollect, graph) } catch { case e : Throwable => Unit } | ||
try { updateSpans(unalignedEntity, graph) } catch { case e : Throwable => Unit } | ||
try { updateSpans(quantity, graph) } catch { case e : Throwable => Unit } | ||
try { updateSpans(argOf, graph) } catch { case e : Throwable => Unit } | ||
try { updateSpans(personOf, graph) } catch { case e : Throwable => Unit } | ||
try { updateSpans(governmentOrg, graph) } catch { case e : Throwable => Unit } | ||
try { updateSpans(polarityChild, graph) } catch { case e : Throwable => Unit } | ||
try { updateSpans(est, graph) } catch { case e : Throwable => Unit } | ||
val spansBefore = dumpSpans("before", graph.spans) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. love it |
||
updateSpans(namedEntityCollect, graph) | ||
updateSpans(unalignedEntity, graph) | ||
updateSpans(quantity, graph) | ||
updateSpans(argOf, graph) | ||
updateSpans(personOf, graph) | ||
updateSpans(governmentOrg, graph) | ||
updateSpans(polarityChild, graph) | ||
updateSpans(est, graph) | ||
val spansAfter = dumpSpans("after", graph.spans) | ||
// print out the spans for comparison before and after update | ||
logger(1, spansBefore) | ||
logger(1, spansAfter) | ||
//try { updateSpans(er, graph) } catch { case e : Throwable => Unit } | ||
//dateEntities(sentence, graph) | ||
//namedEntities(sentence, graph) | ||
|
@@ -231,7 +245,7 @@ object AlignSpans2 { | |
} | ||
|
||
def getSpans(node: Node) : List[Span] = { | ||
logger(2, "Processing node: " + node.concept) | ||
logger(2, "SpanAligner.getSpans Processing node: " + node.concept) | ||
return node match { | ||
case Node(_,_,c,_,_,_,_,_) if ((concept.r.unapplySeq(getConcept(c)) != None) && !node.isAligned(graph)) => { | ||
logger(2, "Matched concept regex: " + concept) | ||
|
@@ -264,7 +278,7 @@ object AlignSpans2 { | |
} | ||
|
||
def update(node: Node) { | ||
logger(2, "SpanUpdater processing node: " + node.concept) | ||
logger(2, "SpanUpdater.update processing node: " + node.concept) | ||
node match { | ||
case Node(_,_,c,_,_,_,_,_) if ((concept.r.unapplySeq(getConcept(c)) != None) && (!node.isAligned(graph) || !unalignedOnly)) => { | ||
logger(2, "SpanUpdater matched concept regex: " + concept) | ||
|
@@ -460,12 +474,21 @@ object AlignSpans2 { | |
} | ||
} | ||
} | ||
// invoke the add() method on each node of the graph beginning with the root | ||
logger(2, f"addAllSpans: concept=${f.concept}") | ||
graph.doRecursive(add) | ||
} | ||
|
||
private def updateSpans(f: AlignSpans2.SpanUpdater, graph: Graph) { | ||
graph.doRecursive(f.update) | ||
} | ||
private def updateSpans(f: AlignSpans2.SpanUpdater, graph: Graph) = { | ||
try { | ||
graph.doRecursive( f.update ) | ||
} | ||
catch { | ||
case e: java.lang.Throwable => { | ||
logger( 3, f"Exception in updateSpans using ${f.getClass.getCanonicalName}: " + e.getMessage) | ||
} | ||
} | ||
} | ||
|
||
/****** </This stuff was originally in Graph> *******/ | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
*.pyc