Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jamr rebase #8

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# generated files
/scripts/smatch_v1_0/amr.pyc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*.pyc

/target
/out
/project/target
Expand All @@ -11,3 +12,5 @@
# IntelliJ cruft
.idea
*.iml
.DS_Store
/tools
99 changes: 66 additions & 33 deletions src/AMRParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package edu.cmu.lti.nlp.amr

import java.io.StringWriter
import java.io.PrintWriter
import java.io.PrintStream

import edu.cmu.lti.nlp.amr.BasicFeatureVector.DecoderResult

import scala.io.Source.fromFile
import scala.collection.mutable.Map
import scala.collection.mutable.Set
Expand Down Expand Up @@ -29,6 +33,7 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
def isSwitch(s : String) = (s(0) == '-')
list match {
case Nil => map
case "--model-name" :: value :: tail => parseOptions(map + ('modelName -> value), tail)
case "--stage1-only" :: l => parseOptions(map + ('stage1Only -> "true"), l)
case "--stage1-oracle" :: l => parseOptions(map + ('stage1Oracle -> "true"), l)
case "--stage1-train" :: l => parseOptions(map + ('stage1Train -> "true"), l)
Expand Down Expand Up @@ -64,6 +69,8 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
case "--ner" :: value :: tail => parseOptions(map + ('ner -> value), tail)
case "--snt" :: value :: tail => parseOptions(map ++ Map('notTokenized -> value), tail)
case "--tok" :: value :: tail => parseOptions(map ++ Map('tokenized -> value), tail)
case "--input" :: value :: tail => parseOptions(map ++ Map('input -> value), tail)
case "--output" :: value :: tail => parseOptions(map ++ Map('output -> value), tail)
case "-v" :: value :: tail => parseOptions(map ++ Map('verbosity -> value), tail)

//case string :: opt2 :: tail if isSwitch(opt2) => parseOptions(map ++ Map('infile -> string), list.tail)
Expand All @@ -77,37 +84,52 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
val now = System.nanoTime
val result = a
val micros = (System.nanoTime - now) / 1000
System.err.println("Decoded in %,d microseconds".format(micros))
logger(0,"Decoded in %,d microseconds".format(micros))
result
}

// Cache the previous Stage2 GraphDecoder object so that it can be re-used. This saves us from having to reload
// the weights if JAMR is re-invoked in the same process. Multiple versions are kept, using modelName as the key.
var previousStage2: Map[String, GraphDecoder.Decoder] = Map()

var previousStage1: Map[String, ConceptInvoke.Decoder] = Map()

// an optional callback that lets us get direct access to decoderResultGraph when parsing
var resultHandler: Option[Graph => Unit] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this ever being set to anything other than None.

Also, adding side-effecting callbacks makes the code harder to grok/change/debug/test. Is there a way to extract a method for the part of the code that produces decoderResultGraph instead?


def main(args: Array[String]) {

if (args.length == 0) { println(usage); sys.exit(1) }
val options = parseOptions(Map(),args.toList)

verbosity = options.getOrElse('verbosity, "0").toInt

val modelName = options.getOrElse('modelName, "Default").toString

val outputFormat = options.getOrElse('outputFormat,"triples").split(",").toList
// Output format is comma separated list of: nodes,edges,AMR,triples

val stage1 : ConceptInvoke.Decoder = {
val stage1 : Option[ConceptInvoke.Decoder] = if (previousStage1.contains(modelName)) previousStage1.get(modelName) else {
if (!options.contains('stage1Oracle) && !options.contains('stage2Train)) {
ConceptInvoke.Decoder(options, oracle = false)
Some(ConceptInvoke.Decoder(options, oracle = false))
} else {
assert(!options.contains('stage1Train), "Error: --stage1-oracle should not be specified with --stage1-train")
ConceptInvoke.Decoder(options, oracle = true)
Some(ConceptInvoke.Decoder(options, oracle = true))
}
}

val stage2 : Option[GraphDecoder.Decoder] = {

val stage2 : Option[GraphDecoder.Decoder] = if (previousStage2.contains(modelName)) previousStage2.get(modelName)
else {
if((options.contains('stage1Only) || options.contains('stage1Train)) && !options.contains('stage2Train)) {
None
} else {
Some(GraphDecoder.Decoder(options))
}
None
} else {
Some(GraphDecoder.Decoder(options))
}
}

if (stage2.isDefined) previousStage2 += (modelName -> stage2.get)

val stage2Oracle : Option[GraphDecoder.Decoder] = {
if(options.contains('trainingData) || options.contains('stage2Train)) {
Some(GraphDecoder.Oracle(options))
Expand All @@ -121,7 +143,7 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
////////////////// Training ////////////////

if (options.contains('stage1Train) && options.contains('stage2Train)) {
System.err.println("Error: please specify either stage1 training or stage2 training (not both)")
logger(0,"Error: please specify either stage1 training or stage2 training (not both)")
sys.exit(1)
}

Expand All @@ -144,28 +166,33 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
/////////////////// Decoding /////////////////

if (!options.contains('stage1Weights)) {
System.err.println("Error: No stage1 weights file specified"); sys.exit(1)
logger(0,"Error: No stage1 weights file specified"); sys.exit(1)
}
stage1.features.weights.read(Source.fromFile(options('stage1Weights).asInstanceOf[String]).getLines())
stage1.get.features.weights.read(Source.fromFile(options('stage1Weights).asInstanceOf[String]).getLines())

//logger(0, "Stage1 weights:\n"+stage1.features.weights.toString)

if (!options.contains('stage2Weights)) {
System.err.println("Error: No stage2 weights file specified")
logger(0,"Error: No stage2 weights file specified")
sys.exit(1)
}
val stage2weightfile : String = options('stage2Weights)

logger(0, "Reading weights")

if (stage2 != None) {
stage2.get.features.weights.read(Source.fromFile(stage2weightfile).getLines())
if (stage2Oracle != None) {
stage2Oracle.get.features.weights.read(Source.fromFile(stage2weightfile).getLines())
// Don't read the weights again if they are already loaded; this check is probably redundant.
if (stage2.get.features.weights.fmap.size == 0) {
logger(0, f"Reading weights for $modelName")
val stage2weightlines = Source.fromFile( stage2weightfile ).getLines( )
stage2.get.features.weights.read( stage2weightlines )
if( stage2Oracle != None ) {
stage2Oracle.get.features.weights.read( Source.fromFile( stage2weightfile ).getLines( ) )
}
}
}
logger(0, "done")

val input = stdin.getLines.toArray
val input = if (options.contains('input)) fromFile(options('input)).getLines().toArray else stdin.getLines.toArray
val tokenized = fromFile(options('tokenized).asInstanceOf[String]).getLines/*.map(x => x)*/.toArray
val nerFile = Corpus.splitOnNewline(fromFile(options('ner).asInstanceOf[String]).getLines).toArray
val oracleData : Array[String] = if (options.contains('trainingData)) {
Expand All @@ -189,6 +216,8 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
}
val spanF1 = F1(0,0,0)

val outStream = if (options.contains('output)) new PrintStream(options('output)) else System.out

for ((block, i) <- input.zipWithIndex) {
try {
time {
Expand All @@ -197,7 +226,7 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
val tok = tokenized(i)
val ner = nerFile(i)
val inputGraph = if (options.contains('stage1Oracle)) { Some(AMRTrainingData(oracleData(i)).toInputGraph) } else { None }
val stage1Result = stage1.decode(new Input(inputGraph,
val stage1Result = stage1.get.decode(new Input(inputGraph,
tok.split(" "),
line.split(" "),
dependencies(i),
Expand Down Expand Up @@ -277,46 +306,50 @@ scala -classpath . edu.cmu.lti.nlp.amr.AMRParser --stage2-decode -w weights -l l
})+"\n")
}

println("# ::snt "+line)
println("# ::tok "+tok)
outStream.println("# ::snt "+line)
outStream.println("# ::tok "+tok)
val sdf = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS")
decoderResultGraph.assignOpN()
decoderResultGraph.sortRelations()
decoderResultGraph.makeIds()
println("# ::alignments "+decoderResultGraph.spans.map(_.format).mkString(" ")+" ::annotator "+VERSION+" ::date "+sdf.format(new Date))

// fire the callback if defined
if (resultHandler.isDefined) resultHandler.get.apply(decoderResultGraph)

outStream.println("# ::alignments "+decoderResultGraph.spans.map(_.format).mkString(" ")+" ::annotator "+VERSION+" ::date "+sdf.format(new Date))
if (outputFormat.contains("nodes")) {
println(decoderResultGraph.printNodes.map(x => "# ::node\t" + x).mkString("\n"))
outStream.println(decoderResultGraph.printNodes.map(x => "# ::node\t" + x).mkString("\n"))
}
if (outputFormat.contains("root")) {
println(decoderResultGraph.printRoot)
}
if (outputFormat.contains("edges") && decoderResultGraph.root.relations.size > 0) {
println(decoderResultGraph.printEdges.map(x => "# ::edge\t" + x).mkString("\n"))
outStream.println(decoderResultGraph.printEdges.map(x => "# ::edge\t" + x).mkString("\n"))
}
if (outputFormat.contains("AMR")) {
println(decoderResultGraph.prettyString(detail=1, pretty=true))
outStream.println(decoderResultGraph.prettyString(detail=1, pretty=true))
}
if (outputFormat.contains("triples")) {
println(decoderResultGraph.printTriples(detail = 1))
outStream.println(decoderResultGraph.printTriples(detail = 1))
}
println()
outStream.println()
} // time
} catch { // try
case e : Throwable => if (options.contains('ignoreParserErrors)) {
println("# ::snt "+input(i))
println("# ::tok "+tokenized(i))
outStream.println("# ::snt "+input(i))
outStream.println("# ::tok "+tokenized(i))
val sdf = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS")
println("# ::alignments 0-1|0 ::annotator "+VERSION+" ::date "+sdf.format(new Date))
println("# THERE WAS AN EXCEPTION IN THE PARSER. Returning an empty graph.")
outStream.println("# ::alignments 0-1|0 ::annotator "+VERSION+" ::date "+sdf.format(new Date))
outStream.println("# THERE WAS AN EXCEPTION IN THE PARSER. Returning an empty graph.")
if (options.contains('printStackTraceOnErrors)) {
val sw = new StringWriter()
e.printStackTrace(new PrintWriter(sw))
println(sw.toString.split("\n").map(x => "# "+x).mkString("\n"))
outStream.println(sw.toString.split("\n").map(x => "# "+x).mkString("\n"))
}
logger(-1, " ********** THERE WAS AN EXCEPTION IN THE PARSER. *********")
if (verbosity >= -1) { e.printStackTrace }
logger(-1, "Continuing. To exit on errors, please run without --ignore-parser-errors")
println(Graph.empty.prettyString(detail=1, pretty=true) + '\n')
outStream.println(Graph.empty.prettyString(detail=1, pretty=true) + '\n')
} else {
throw e
}
Expand Down
53 changes: 38 additions & 15 deletions src/AlignSpans2.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package edu.cmu.lti.nlp.amr

import java.lang.Throwable
import java.util.regex.Pattern
import scala.util.{Failure, Success, Try}
import scala.util.matching.Regex
import scala.collection.mutable.Map
import scala.collection.mutable.Set
Expand All @@ -12,13 +14,20 @@ object AlignSpans2 {
// True when "abc".split("") yields Array("a","b","c"), i.e. without a leading empty string.
// The difference between systems comes from the Java version: since Java 8, String.split no
// longer produces a leading empty substring for a zero-width match at the start of the input.
// On Java 7 and earlier the result is Array("","a","b","c") and this flag is false.
val weird_system: Boolean = "abc".split("").toList == List("a", "b", "c")

/** Captures the given spans in a string, one span per line, each prefixed with `marker`,
  * so that the spans can be compared before and after update operations.
  *
  * @param marker label written in front of every span (e.g. "before" or "after")
  * @param spans  the spans to render
  * @return the concatenation of one "marker: span\n" line per span; the empty string when
  *         `spans` is empty
  */
def dumpSpans(marker: String, spans: ArrayBuffer[Span]): String =
  spans.map(span => s"$marker: $span\n").mkString

def align(sentence: Array[String], graph: Graph) {
val stemmedSentence = sentence.map(stemmer(_))
val wordToSpan : Array[Option[Int]] = sentence.map(x => None)
logger(3, "Stemmed sentence "+stemmedSentence.toList.toString)
logger(1, "Stemmed sentence "+stemmedSentence.toList.toString)

val namedEntity = new SpanAligner(sentence, graph) {
concept = "name"
Expand Down Expand Up @@ -187,14 +196,19 @@ object AlignSpans2 {
addAllSpans(singleConcept, graph, wordToSpan, addCoRefs=false)
addAllSpans(fuzzyConcept, graph, wordToSpan, addCoRefs=false)
addAllSpans(US, graph, wordToSpan, addCoRefs=false)
try { updateSpans(namedEntityCollect, graph) } catch { case e : Throwable => Unit }
try { updateSpans(unalignedEntity, graph) } catch { case e : Throwable => Unit }
try { updateSpans(quantity, graph) } catch { case e : Throwable => Unit }
try { updateSpans(argOf, graph) } catch { case e : Throwable => Unit }
try { updateSpans(personOf, graph) } catch { case e : Throwable => Unit }
try { updateSpans(governmentOrg, graph) } catch { case e : Throwable => Unit }
try { updateSpans(polarityChild, graph) } catch { case e : Throwable => Unit }
try { updateSpans(est, graph) } catch { case e : Throwable => Unit }
val spansBefore = dumpSpans("before", graph.spans)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love it

updateSpans(namedEntityCollect, graph)
updateSpans(unalignedEntity, graph)
updateSpans(quantity, graph)
updateSpans(argOf, graph)
updateSpans(personOf, graph)
updateSpans(governmentOrg, graph)
updateSpans(polarityChild, graph)
updateSpans(est, graph)
val spansAfter = dumpSpans("after", graph.spans)
// print out the spans for comparison before and after update
logger(1, spansBefore)
logger(1, spansAfter)
//try { updateSpans(er, graph) } catch { case e : Throwable => Unit }
//dateEntities(sentence, graph)
//namedEntities(sentence, graph)
Expand Down Expand Up @@ -231,7 +245,7 @@ object AlignSpans2 {
}

def getSpans(node: Node) : List[Span] = {
logger(2, "Processing node: " + node.concept)
logger(2, "SpanAligner.getSpans Processing node: " + node.concept)
return node match {
case Node(_,_,c,_,_,_,_,_) if ((concept.r.unapplySeq(getConcept(c)) != None) && !node.isAligned(graph)) => {
logger(2, "Matched concept regex: " + concept)
Expand Down Expand Up @@ -264,7 +278,7 @@ object AlignSpans2 {
}

def update(node: Node) {
logger(2, "SpanUpdater processing node: " + node.concept)
logger(2, "SpanUpdater.update processing node: " + node.concept)
node match {
case Node(_,_,c,_,_,_,_,_) if ((concept.r.unapplySeq(getConcept(c)) != None) && (!node.isAligned(graph) || !unalignedOnly)) => {
logger(2, "SpanUpdater matched concept regex: " + concept)
Expand Down Expand Up @@ -460,12 +474,21 @@ object AlignSpans2 {
}
}
}
// invoke the add() method on each node of the graph beginning with the root
logger(2, f"addAllSpans: concept=${f.concept}")
graph.doRecursive(add)
}

// Runs the updater's `update` method over the graph via graph.doRecursive
// (presumably visiting every node starting from the root -- confirm against Graph.doRecursive).
// This version has no error handling: any exception thrown during the traversal
// propagates to the caller.
private def updateSpans(f: AlignSpans2.SpanUpdater, graph: Graph) {
graph.doRecursive(f.update)
}
/** Runs the updater's `update` method over the graph via `graph.doRecursive`, as a
  * best-effort step: a failure in one updater is logged and does not abort the caller.
  *
  * Only non-fatal exceptions (as classified by `scala.util.control.NonFatal`) are
  * suppressed. Fatal JVM errors such as OutOfMemoryError, and InterruptedException,
  * propagate to the caller instead of being silently swallowed.
  *
  * @param f     the span updater whose `update` method is applied
  * @param graph the graph to traverse
  */
private def updateSpans(f: AlignSpans2.SpanUpdater, graph: Graph): Unit = {
  try {
    graph.doRecursive(f.update)
  } catch {
    case scala.util.control.NonFatal(e) =>
      // getName instead of getCanonicalName: the latter returns null for anonymous
      // classes, which would make the log line read "using null".
      logger(3, s"Exception in updateSpans using ${f.getClass.getName}: " + e.getMessage)
  }
}

/****** </This stuff was originally in Graph> *******/

Expand Down
2 changes: 1 addition & 1 deletion src/AlignWords.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ object AlignWords {
}
}
if (!found) {
//logger(2,"CONCEPT NOT FOUND: "+node.concept+" by searching "+concept)
logger(2,"CONCEPT NOT FOUND: "+node.concept+" by searching "+concept)
}
for ((_, child) <- node.topologicalOrdering) {
alignWords(stemmedSentence, child, alignments)
Expand Down
Loading