From decc9b89f1583e06a2d4f88921007dada951b66f Mon Sep 17 00:00:00 2001 From: Keith Alcock Date: Fri, 10 Sep 2021 08:45:17 -0700 Subject: [PATCH 1/4] Make room for other backend (NN) implementations Add a layer of abstraction to support a TorchScript version and others --- .../FastNLPProcessorWithSemanticRoles.scala | 13 ++- .../clulab/dynet/ConstEmbeddingsGlove.scala | 8 +- .../clulab/processors/clu/CluProcessor.scala | 91 +++++++++---------- .../processors/clu/backend/CluBackend.scala | 43 +++++++++ .../processors/clu/backend/MetalBackend.scala | 51 +++++++++++ .../processors/clu/backend/ScalaBackend.scala | 29 ++++++ .../processors/clu/backend/TorchBackend.scala | 29 ++++++ 7 files changed, 212 insertions(+), 52 deletions(-) create mode 100644 main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala create mode 100644 main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala create mode 100644 main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala create mode 100644 main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala diff --git a/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala b/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala index 662a0971c..5a713ecdd 100644 --- a/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala +++ b/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala @@ -3,6 +3,8 @@ package org.clulab.processors.fastnlp import org.clulab.dynet.{ConstEmbeddingsGlove, Utils} import org.clulab.processors.Document import org.clulab.processors.clu.CluProcessor +import org.clulab.processors.clu.backend.EmbeddingsAttachment +import org.clulab.processors.clu.backend.MetalBackend import org.clulab.processors.clu.tokenizer.TokenizerStep import org.clulab.processors.shallownlp.ShallowNLPProcessor import org.clulab.struct.GraphMap @@ -22,7 +24,7 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS new CluProcessor() { // Since this skips CluProcessor.srl() and goes straight for srlSentence(), there isn't // a chance to make sure CluProcessor.mtlSrla is initialized, so it is done here. - assert(this.mtlSrla != null) + assert(this.srlaBackend != null) } } @@ -34,21 +36,21 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS } override def srl(doc: Document): Unit = { - val embeddings = ConstEmbeddingsGlove.mkConstLookupParams(doc) + val embeddingsAttachment = new EmbeddingsAttachment(MetalBackend.mkEmbeddings(doc)) val docDate = doc.getDCT for(sent <- doc.sentences) { val words = sent.words val lemmas = sent.lemmas // The SRL model relies on NEs produced by CluProcessor, so run the NER first - val (tags, _, preds) = cluProcessor.tagSentence(words, embeddings) + val (tags, _, preds) = cluProcessor.tagSentence(words, embeddingsAttachment) val predIndexes = cluProcessor.getPredicateIndexes(preds) val tagsAsArray = tags.toArray val (entities, _) = cluProcessor.nerSentence( - words, lemmas, tagsAsArray, sent.startOffsets, sent.endOffsets, docDate, embeddings + words, lemmas, tagsAsArray, sent.startOffsets, sent.endOffsets, docDate, embeddingsAttachment ) val semanticRoles = cluProcessor.srlSentence( - words, tagsAsArray, entities, predIndexes, embeddings + words, tagsAsArray, entities, predIndexes, embeddingsAttachment ) sent.graphs += GraphMap.SEMANTIC_ROLES -> semanticRoles @@ -60,5 +62,6 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS sent.graphs += GraphMap.ENHANCED_SEMANTIC_ROLES -> enhancedRoles } } + embeddingsAttachment.close() } } diff --git a/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala b/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala index c5c67066b..ac5933103 100644 --- a/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala +++ b/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala @@ -7,6 +7,7 @@ import org.slf4j.{Logger, LoggerFactory} import org.clulab.utils.ConfigWithDefaults import org.clulab.utils.StringUtils +import java.io.Closeable import scala.collection.mutable class ConstEmbeddingsGlove @@ -14,7 +15,12 @@ class ConstEmbeddingsGlove /** Stores lookup parameters + the map from strings to ids */ case class ConstEmbeddingParameters(collection: ParameterCollection, lookupParameters: LookupParameter, - w2i: Map[String, Int]) + w2i: Map[String, Int]) extends Closeable { + def close(): Unit = { + lookupParameters.close() + collection.close() + } +} /** * Implements the ConstEmbeddings as a thin wrapper around WordEmbeddingMap diff --git a/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala b/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala index 13f13a5fb..a5e82f953 100644 --- a/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala +++ b/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala @@ -9,8 +9,18 @@ import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import CluProcessor._ -import org.clulab.dynet.{AnnotatedSentence, ConstEmbeddingParameters, ConstEmbeddingsGlove, Metal} +import org.clulab.processors.clu.backend.DepsBackend +import org.clulab.processors.clu.backend.MetalDepsBackend +import org.clulab.processors.clu.backend.MetalNerBackend +import org.clulab.processors.clu.backend.MetalPosBackend +import org.clulab.processors.clu.backend.MetalSrlaBackend +import org.clulab.processors.clu.backend.NerBackend +import org.clulab.processors.clu.backend.PosBackend +import org.clulab.processors.clu.backend.SrlaBackend +import org.clulab.dynet.AnnotatedSentence import org.clulab.numeric.{NumericEntityRecognizer, setLabelsAndNorms} +import org.clulab.processors.clu.backend.EmbeddingsAttachment +import org.clulab.processors.clu.backend.MetalBackend import org.clulab.sequences.LexiconNER import org.clulab.struct.{DirectedGraph, Edge, GraphMap} import org.clulab.utils.BeforeAndAfter @@ -52,27 +62,28 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), } // one of the multi-task learning (MTL) models, which covers: POS, chunking, and SRL (predicates) - lazy val mtlPosChunkSrlp: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val posBackend: PosBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-pos-chunk-srlp", Some("mtl-en-pos-chunk-srlp"))) + // Use the config to decide which one to use. + case _ => new MetalPosBackend(getArgString(s"$prefix.mtl-pos-chunk-srlp", Some("mtl-en-pos-chunk-srlp"))) } // one of the multi-task learning (MTL) models, which covers: NER - lazy val mtlNer: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val nerBackend: NerBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-ner", Some("mtl-en-ner"))) + case _ => new MetalNerBackend(getArgString(s"$prefix.mtl-ner", Some("mtl-en-ner"))) } // recognizes numeric entities using Odin rules lazy val numericEntityRecognizer = new NumericEntityRecognizer // one of the multi-task learning (MTL) models, which covers: SRL (arguments) - lazy val mtlSrla: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val srlaBackend: SrlaBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-srla", Some("mtl-en-srla"))) + case _ => new MetalSrlaBackend(getArgString(s"$prefix.mtl-srla", Some("mtl-en-srla"))) } /* @@ -88,10 +99,10 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), case _ => Metal(getArgString(s"$prefix.mtl-depsl", Some("mtl-en-depsl"))) } */ - lazy val mtlDeps: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val depsBackend: DepsBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-deps", Some("mtl-en-deps"))) + case _ => new MetalDepsBackend(getArgString(s"$prefix.mtl-deps", Some("mtl-en-deps"))) } // Although this uses no class members, the method is sometimes called from tests @@ -149,14 +160,10 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), class PredicateAttachment(val predicates: IndexedSeq[IndexedSeq[Int]]) extends IntermediateDocumentAttachment /** Produces POS tags, chunks, and semantic role predicates for one sentence */ - def tagSentence(words: IndexedSeq[String], embeddings: ConstEmbeddingParameters): + def tagSentence(words: IndexedSeq[String], embeddingsAttachment: EmbeddingsAttachment): (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { - val allLabels = mtlPosChunkSrlp.predictJointly(AnnotatedSentence(words), embeddings) - val tags = allLabels(0) - val chunks = allLabels(1) - val preds = allLabels(2) - (tags, chunks, preds) + posBackend.predict(AnnotatedSentence(words), embeddingsAttachment) } /** Produces NE labels for one sentence */ @@ -166,10 +173,10 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), startCharOffsets: Array[Int], endCharOffsets: Array[Int], docDateOpt: Option[String], - embeddings: ConstEmbeddingParameters): (IndexedSeq[String], Option[IndexedSeq[String]]) = { + embeddingsAttachment: EmbeddingsAttachment): (IndexedSeq[String], Option[IndexedSeq[String]]) = { // NER labels from the statistical model - val allLabels = mtlNer.predictJointly(AnnotatedSentence(words), embeddings) + val labels = nerBackend.predict(AnnotatedSentence(words), embeddingsAttachment) // NER labels from the custom NER val optionalNERLabels: Option[Array[String]] = { @@ -196,9 +203,9 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), } if(optionalNERLabels.isEmpty) { - (allLabels(0), None) + (labels, None) } else { - (mergeNerLabels(allLabels(0), optionalNERLabels.get), None) + (mergeNerLabels(labels, optionalNERLabels.get), None) } } @@ -239,7 +246,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), def parseSentence(words: IndexedSeq[String], posTags: IndexedSeq[String], nerLabels: IndexedSeq[String], - embeddings: ConstEmbeddingParameters): DirectedGraph[String] = { + embeddingsAttachment: EmbeddingsAttachment): DirectedGraph[String] = { //println(s"Words: ${words.mkString(", ")}") //println(s"Tags: ${posTags.mkString(", ")}") @@ -248,7 +255,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), val annotatedSentence = AnnotatedSentence(words, Some(posTags), Some(nerLabels)) - val headsAndLabels = mtlDeps.parse(annotatedSentence, embeddings) + val headsAndLabels = depsBackend.predict(annotatedSentence, embeddingsAttachment) val edges = new ListBuffer[Edge[String]]() val roots = new mutable.HashSet[Int]() @@ -323,12 +330,12 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), def srlSentence(sent: Sentence, predicateIndexes: IndexedSeq[Int], - embeddings: ConstEmbeddingParameters): DirectedGraph[String] = { + embeddingsAttachment: EmbeddingsAttachment): DirectedGraph[String] = { // the SRL models were trained using only named (CoNLL) entities, not numeric ones, so let's remove them // TODO: retrain the SRL using numeric entities too, when NumericEntityRecognizer is stable val onlyNamedLabels = removeNumericLabels(sent.entities.get) - srlSentence(sent.words, sent.tags.get, onlyNamedLabels, predicateIndexes, embeddings) + srlSentence(sent.words, sent.tags.get, onlyNamedLabels, predicateIndexes, embeddingsAttachment) } def removeNumericLabels(allLabels: Array[String]): Array[String] = { @@ -349,12 +356,12 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), posTags: IndexedSeq[String], nerLabels: IndexedSeq[String], predicateIndexes: IndexedSeq[Int], - embeddings: ConstEmbeddingParameters): DirectedGraph[String] = { + embeddingsAttachment: EmbeddingsAttachment): DirectedGraph[String] = { val edges = predicateIndexes.flatMap { pred => // SRL needs POS tags and NEs, as well as the position of the predicate val headPositions = Array.fill(words.length)(pred) val annotatedSentence = AnnotatedSentence(words, Some(posTags), Some(nerLabels), Some(headPositions)) - val argLabels = mtlSrla.predict(0, annotatedSentence, embeddings) + val argLabels = srlaBackend.predict(0, annotatedSentence, embeddingsAttachment) argLabels.zipWithIndex .filter { case (argLabel, _) => argLabel != "O" } @@ -363,14 +370,14 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), new DirectedGraph[String](edges.toList, Some(words.length)) } - private def getEmbeddings(doc: Document): ConstEmbeddingParameters = - doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).get.asInstanceOf[EmbeddingsAttachment].embeddings + private def getEmbeddingsAttachment(doc: Document): EmbeddingsAttachment = + doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).get.asInstanceOf[EmbeddingsAttachment] /** Part of speech tagging + chunking + SRL (predicates), jointly */ override def tagPartsOfSpeech(doc:Document) { basicSanityCheck(doc) - val embeddings = getEmbeddings(doc) + val embeddings = getEmbeddingsAttachment(doc) val predsForAllSents = new ArrayBuffer[IndexedSeq[Int]]() for(sent <- doc.sentences) { @@ -434,7 +441,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), // // names entities from CoNLL // - val embeddings = getEmbeddings(doc) + val embeddingsAttachment = getEmbeddingsAttachment(doc) val docDate = doc.getDCT for(sent <- doc.sentences) { val (labels, norms) = nerSentence( @@ -444,7 +451,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), sent.startOffsets, sent.endOffsets, docDate, - embeddings) + embeddingsAttachment) sent.entities = Some(labels.toArray) if(norms.nonEmpty) { @@ -473,7 +480,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), if(sentence.universalBasicDependencies.isEmpty) return origPreds if(sentence.tags.isEmpty) return origPreds - + val preds = origPreds.toSet val newPreds = new mutable.HashSet[Int]() newPreds ++= preds @@ -499,7 +506,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), val predicatesAttachment = doc.getAttachment(PREDICATE_ATTACHMENT_NAME) assert(predicatesAttachment.nonEmpty) assert(doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).isDefined) - val embeddings = getEmbeddings(doc) + val embeddingsAttachment = getEmbeddingsAttachment(doc) if(doc.sentences.length > 0) { assert(doc.sentences(0).tags.nonEmpty) @@ -515,14 +522,14 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), // until later, possibly in a parallel processing situation which will result sooner or later and // under the right conditions in DyNet crashing. Therefore, make preemptive reference to mtlSrla // here so that it is sure to be initialized, regardless of the priming sentence. - assert(mtlSrla != null) + assert(srlaBackend != null) // generate SRL frames for each predicate in each sentence for(si <- predicates.indices) { val sentence = doc.sentences(si) val predicateIndexes = predicateCorrections(predicates(si), sentence) - val semanticRoles = srlSentence(sentence, predicateIndexes, embeddings) + val semanticRoles = srlSentence(sentence, predicateIndexes, embeddingsAttachment) sentence.graphs += GraphMap.SEMANTIC_ROLES -> semanticRoles @@ -554,7 +561,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), assert(doc.sentences(0).entities.nonEmpty) } assert(doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).isDefined) - val embeddings = getEmbeddings(doc) + val embeddings = getEmbeddingsAttachment(doc) for(sent <- doc.sentences) { val depGraph = parseSentence(sent.words, sent.tags.get, sent.entities.get, embeddings) @@ -742,9 +749,6 @@ object CluProcessor { } } -case class EmbeddingsAttachment(embeddings: ConstEmbeddingParameters) - extends IntermediateDocumentAttachment - class GivenConstEmbeddingsAttachment(doc: Document) extends BeforeAndAfter { def before(): Unit = GivenConstEmbeddingsAttachment.mkConstEmbeddings(doc) @@ -752,12 +756,7 @@ class GivenConstEmbeddingsAttachment(doc: Document) extends BeforeAndAfter { def after(): Unit = { val attachment = doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).get.asInstanceOf[EmbeddingsAttachment] doc.removeAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME) - - // This is a memory management optimization. - val embeddings = attachment.embeddings - // FatDynet needs to be updated before these can be used. - embeddings.lookupParameters.close() - embeddings.collection.close() + attachment.close() } } @@ -767,8 +766,8 @@ object GivenConstEmbeddingsAttachment { // This is static so that it can be called without an object. def mkConstEmbeddings(doc: Document): Unit = { // Fetch the const embeddings from GloVe. All our models need them. - val embeddings = ConstEmbeddingsGlove.mkConstLookupParams(doc) - val attachment = EmbeddingsAttachment(embeddings) + val embeddings = MetalBackend.mkEmbeddings(doc) + val attachment = new EmbeddingsAttachment(embeddings) // Now set them as an attachment, so they are available to all downstream methods wo/ changing the API. doc.addAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME, attachment) diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala new file mode 100644 index 000000000..9cc28822a --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala @@ -0,0 +1,43 @@ +package org.clulab.processors.clu.backend + +import org.clulab.dynet.AnnotatedSentence +import org.clulab.processors.Document +import org.clulab.processors.IntermediateDocumentAttachment + +import java.io.Closeable + +case class CloseableNone() extends Closeable { + def close(): Unit = () +} + +trait CluBackend { + def mkEmbeddings(doc: Document): Closeable = CloseableNone() +} + +class EmbeddingsAttachment(protected val value: Closeable) extends IntermediateDocumentAttachment with Closeable { + // TODO: This would need to change if multiple values were to be stored. + + def get[T]: T = value.asInstanceOf[T] // Use caution! + + def close(): Unit = value.close() +} + +trait PosBackend { + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) // tags, chunks, and preds +} + +trait NerBackend { + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] // labels +} + +trait SrlaBackend { + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] // labels +} + +trait DepsBackend { + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala new file mode 100644 index 000000000..8005a2f6f --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala @@ -0,0 +1,51 @@ +package org.clulab.processors.clu.backend + +import org.clulab.dynet.AnnotatedSentence +import org.clulab.dynet.ConstEmbeddingsGlove +import org.clulab.dynet.Metal +import org.clulab.processors.Document + +import java.io.Closeable + +object MetalBackend extends CluBackend { + + override def mkEmbeddings(doc: Document): Closeable = { + // Fetch the const embeddings from GloVe. All our models need them. + ConstEmbeddingsGlove.mkConstLookupParams(doc) + } +} + +class MetalPosBackend(modelFilenamePrefix: String) extends PosBackend { + lazy val mtlPosChunkSrlp: Metal = Metal(modelFilenamePrefix) + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { + val IndexedSeq(tags, chunks, preds, _*) = mtlPosChunkSrlp.predictJointly(annotatedSentence, embeddingsAttachment.get) + + (tags, chunks, preds) + } +} + +class MetalNerBackend(modelFilenamePrefix: String) extends NerBackend { + lazy val mtlNer: Metal = Metal(modelFilenamePrefix) + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = + mtlNer.predictJointly(annotatedSentence, embeddingsAttachment.get).head // labels +} + +class MetalSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + lazy val mtlSrla: Metal = Metal(modelFilenamePrefix) + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = + mtlSrla.predict(taskId, annotatedSentence, embeddingsAttachment.get) // labels +} + +class MetalDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + lazy val mtlDeps: Metal = Metal(modelFilenamePrefix) + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = + mtlDeps.parse(annotatedSentence, embeddingsAttachment.get) // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala new file mode 100644 index 000000000..bfdf57a36 --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala @@ -0,0 +1,29 @@ +package org.clulab.processors.clu.backend + +import org.clulab.dynet.AnnotatedSentence + +object ScalaBackend extends CluBackend + +class ScalaPosBackend(modelFilenamePrefix: String) extends PosBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = ??? // tags, chunks, and preds +} + +class ScalaNerBackend(modelFilenamePrefix: String) extends NerBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class ScalaSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class ScalaDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = ??? // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala new file mode 100644 index 000000000..1f5a0e5aa --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala @@ -0,0 +1,29 @@ +package org.clulab.processors.clu.backend + +import org.clulab.dynet.AnnotatedSentence + +object TorchBackend extends CluBackend + +class TorchPosBackend(modelFilenamePrefix: String) extends PosBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = ??? // tags, chunks, and preds +} + +class TorchNerBackend(modelFilenamePrefix: String) extends NerBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class TorchSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class TorchDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = ??? // heads and labels +} From e618b9e50497e48fe856b3dca71f84c6747ba14c Mon Sep 17 00:00:00 2001 From: Keith Alcock Date: Fri, 10 Sep 2021 09:04:06 -0700 Subject: [PATCH 2/4] Avoid double laziness --- .../org/clulab/processors/clu/backend/MetalBackend.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala index 8005a2f6f..3d656be0a 100644 --- a/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala +++ b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala @@ -16,7 +16,7 @@ object MetalBackend extends CluBackend { } class MetalPosBackend(modelFilenamePrefix: String) extends PosBackend { - lazy val mtlPosChunkSrlp: Metal = Metal(modelFilenamePrefix) + protected val mtlPosChunkSrlp: Metal = Metal(modelFilenamePrefix) def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { @@ -27,7 +27,7 @@ class MetalPosBackend(modelFilenamePrefix: String) extends PosBackend { } class MetalNerBackend(modelFilenamePrefix: String) extends NerBackend { - lazy val mtlNer: Metal = Metal(modelFilenamePrefix) + protected val mtlNer: Metal = Metal(modelFilenamePrefix) def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): IndexedSeq[String] = @@ -35,7 +35,7 @@ class MetalNerBackend(modelFilenamePrefix: String) extends NerBackend { } class MetalSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { - lazy val mtlSrla: Metal = Metal(modelFilenamePrefix) + protected val mtlSrla: Metal = Metal(modelFilenamePrefix) def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): IndexedSeq[String] = @@ -43,7 +43,7 @@ class MetalSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { } class MetalDepsBackend(modelFilenamePrefix: String) extends DepsBackend { - lazy val mtlDeps: Metal = Metal(modelFilenamePrefix) + protected val mtlDeps: Metal = Metal(modelFilenamePrefix) def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): IndexedSeq[(Int, String)] = From 008580d2a97ffb5f1fe9c962b54aa0947a4040e6 Mon Sep 17 00:00:00 2001 From: Keith Alcock Date: Tue, 5 Oct 2021 16:40:35 -0700 Subject: [PATCH 3/4] Move AnnotatedSentence --- .../main/scala/org/clulab/dynet/EmbeddingLayer.scala | 3 +-- .../main/scala/org/clulab/dynet/InitialLayer.scala | 1 + main/src/main/scala/org/clulab/dynet/Layers.scala | 1 + main/src/main/scala/org/clulab/dynet/Metal.scala | 3 +-- main/src/main/scala/org/clulab/dynet/MetalShell.scala | 1 + main/src/main/scala/org/clulab/dynet/RowReaders.scala | 10 +--------- .../org/clulab/processors/clu/AnnotatedSentence.scala | 11 +++++++++++ .../org/clulab/processors/clu/CluProcessor.scala | 1 - .../clulab/processors/clu/backend/CluBackend.scala | 2 +- .../clulab/processors/clu/backend/MetalBackend.scala | 2 +- .../clulab/processors/clu/backend/TorchBackend.scala | 2 +- 11 files changed, 20 insertions(+), 17 deletions(-) create mode 100644 main/src/main/scala/org/clulab/processors/clu/AnnotatedSentence.scala diff --git a/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala b/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala index 28dda1e34..dc7cde9b4 100644 --- a/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala +++ b/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala @@ -1,15 +1,14 @@ package org.clulab.dynet import java.io.PrintWriter - import edu.cmu.dynet.Expression.concatenate import edu.cmu.dynet.{Dim, Expression, ExpressionVector, LookupParameter, LstmBuilder, ParameterCollection, RnnBuilder} import org.clulab.struct.Counter import org.slf4j.{Logger, LoggerFactory} import org.clulab.dynet.Utils._ import org.clulab.utils.{Configured, Serializer} - import EmbeddingLayer._ +import org.clulab.processors.clu.AnnotatedSentence import scala.util.Random diff --git a/main/src/main/scala/org/clulab/dynet/InitialLayer.scala b/main/src/main/scala/org/clulab/dynet/InitialLayer.scala index 9261a6ee9..21c076463 100644 --- a/main/src/main/scala/org/clulab/dynet/InitialLayer.scala +++ b/main/src/main/scala/org/clulab/dynet/InitialLayer.scala @@ -1,6 +1,7 @@ package org.clulab.dynet import edu.cmu.dynet.{ExpressionVector, LookupParameter} +import org.clulab.processors.clu.AnnotatedSentence /** * First layer that occurs in a sequence modeling architecture: goes from words to Expressions diff --git a/main/src/main/scala/org/clulab/dynet/Layers.scala b/main/src/main/scala/org/clulab/dynet/Layers.scala index d5a34b10f..e6f9ee1f8 100644 --- a/main/src/main/scala/org/clulab/dynet/Layers.scala +++ b/main/src/main/scala/org/clulab/dynet/Layers.scala @@ -6,6 +6,7 @@ import org.clulab.struct.Counter import org.clulab.utils.Configured import org.clulab.dynet.Utils._ import org.clulab.fatdynet.utils.Synchronizer +import org.clulab.processors.clu.AnnotatedSentence import scala.collection.mutable.ArrayBuffer diff --git a/main/src/main/scala/org/clulab/dynet/Metal.scala b/main/src/main/scala/org/clulab/dynet/Metal.scala index 5dc99cb2c..e327e99e4 100644 --- a/main/src/main/scala/org/clulab/dynet/Metal.scala +++ b/main/src/main/scala/org/clulab/dynet/Metal.scala @@ -1,7 +1,6 @@ package org.clulab.dynet import java.io.{FileWriter, PrintWriter} - import com.typesafe.config.ConfigFactory import edu.cmu.dynet.{AdamTrainer, ComputationGraph, Expression, ExpressionVector, ParameterCollection, RMSPropTrainer, SimpleSGDTrainer} import org.clulab.dynet.Utils._ @@ -14,8 +13,8 @@ import org.clulab.fatdynet.utils.Closer.AutoCloser import scala.collection.mutable.ArrayBuffer import scala.util.Random - import Metal._ +import org.clulab.processors.clu.AnnotatedSentence /** * Multi-task learning (MeTaL) for sequence modeling diff --git a/main/src/main/scala/org/clulab/dynet/MetalShell.scala b/main/src/main/scala/org/clulab/dynet/MetalShell.scala index 256510dae..b1acbe250 100644 --- a/main/src/main/scala/org/clulab/dynet/MetalShell.scala +++ b/main/src/main/scala/org/clulab/dynet/MetalShell.scala @@ -1,5 +1,6 @@ package org.clulab.dynet +import org.clulab.processors.clu.AnnotatedSentence import org.clulab.utils.Shell class MetalShell(val mtl: Metal) extends Shell { diff --git a/main/src/main/scala/org/clulab/dynet/RowReaders.scala b/main/src/main/scala/org/clulab/dynet/RowReaders.scala index 098fd61d6..7eaad30a0 100644 --- a/main/src/main/scala/org/clulab/dynet/RowReaders.scala +++ b/main/src/main/scala/org/clulab/dynet/RowReaders.scala @@ -8,16 +8,8 @@ package org.clulab.dynet import org.clulab.sequences.Row import scala.collection.mutable.ArrayBuffer - import MetalRowReader._ - -case class AnnotatedSentence(words: IndexedSeq[String], - posTags: Option[IndexedSeq[String]] = None, - neTags: Option[IndexedSeq[String]] = None, - headPositions: Option[IndexedSeq[Int]] = None) { - def indices: Range = words.indices - def size: Int = words.size -} +import org.clulab.processors.clu.AnnotatedSentence trait RowReader { /** Converts the tabular format into one or more (AnnotatedSentence, sequence of gold labels) pairs */ diff --git a/main/src/main/scala/org/clulab/processors/clu/AnnotatedSentence.scala b/main/src/main/scala/org/clulab/processors/clu/AnnotatedSentence.scala new file mode 100644 index 000000000..50c44a413 --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/AnnotatedSentence.scala @@ -0,0 +1,11 @@ +package org.clulab.processors.clu + +case class AnnotatedSentence( + words: IndexedSeq[String], + posTags: Option[IndexedSeq[String]] = None, + neTags: Option[IndexedSeq[String]] = None, + headPositions: Option[IndexedSeq[Int]] = None +) { + def indices: Range = words.indices + def size: Int = words.size +} diff --git a/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala b/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala index a5e82f953..0febc3193 100644 --- a/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala +++ b/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala @@ -17,7 +17,6 @@ import org.clulab.processors.clu.backend.MetalSrlaBackend import org.clulab.processors.clu.backend.NerBackend import org.clulab.processors.clu.backend.PosBackend import org.clulab.processors.clu.backend.SrlaBackend -import org.clulab.dynet.AnnotatedSentence import org.clulab.numeric.{NumericEntityRecognizer, setLabelsAndNorms} import org.clulab.processors.clu.backend.EmbeddingsAttachment import org.clulab.processors.clu.backend.MetalBackend diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala index 9cc28822a..074f015c7 100644 --- a/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala +++ b/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala @@ -1,8 +1,8 @@ package org.clulab.processors.clu.backend -import org.clulab.dynet.AnnotatedSentence import org.clulab.processors.Document import org.clulab.processors.IntermediateDocumentAttachment +import org.clulab.processors.clu.AnnotatedSentence import java.io.Closeable diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala index 3d656be0a..4c63dd6ed 100644 --- a/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala +++ b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala @@ -1,9 +1,9 @@ package org.clulab.processors.clu.backend -import org.clulab.dynet.AnnotatedSentence import org.clulab.dynet.ConstEmbeddingsGlove import org.clulab.dynet.Metal import org.clulab.processors.Document +import org.clulab.processors.clu.AnnotatedSentence import java.io.Closeable diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala index 1f5a0e5aa..f950f11a2 100644 --- a/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala +++ b/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala @@ -1,6 +1,6 @@ package org.clulab.processors.clu.backend -import org.clulab.dynet.AnnotatedSentence +import org.clulab.processors.clu.AnnotatedSentence object TorchBackend extends CluBackend From 4f8cd3a3eab992b39bc36dc8bf8b2b76c2e6ed47 Mon Sep 17 00:00:00 2001 From: Keith Alcock Date: Tue, 5 Oct 2021 16:46:31 -0700 Subject: [PATCH 4/4] Add Onnx backend --- .../processors/clu/backend/OnnxBackend.scala | 29 +++++++++++++++++++ .../processors/clu/backend/ScalaBackend.scala | 2 +- 2 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 main/src/main/scala/org/clulab/processors/clu/backend/OnnxBackend.scala diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/OnnxBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/OnnxBackend.scala new file mode 100644 index 000000000..7875609af --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/OnnxBackend.scala @@ -0,0 +1,29 @@ +package org.clulab.processors.clu.backend + +import org.clulab.processors.clu.AnnotatedSentence + +object OnnxBackend extends CluBackend + +class OnnxPosBackend(modelFilenamePrefix: String) extends PosBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = ??? // tags, chunks, and preds +} + +class OnnxNerBackend(modelFilenamePrefix: String) extends NerBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class OnnxSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class OnnxDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = ??? // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala index bfdf57a36..0b2115dd3 100644 --- a/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala +++ b/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala @@ -1,6 +1,6 @@ package org.clulab.processors.clu.backend -import org.clulab.dynet.AnnotatedSentence +import org.clulab.processors.clu.AnnotatedSentence object ScalaBackend extends CluBackend