diff --git a/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala b/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala index 662a0971c..5a713ecdd 100644 --- a/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala +++ b/corenlp/src/main/scala/org/clulab/processors/fastnlp/FastNLPProcessorWithSemanticRoles.scala @@ -3,6 +3,8 @@ package org.clulab.processors.fastnlp import org.clulab.dynet.{ConstEmbeddingsGlove, Utils} import org.clulab.processors.Document import org.clulab.processors.clu.CluProcessor +import org.clulab.processors.clu.backend.EmbeddingsAttachment +import org.clulab.processors.clu.backend.MetalBackend import org.clulab.processors.clu.tokenizer.TokenizerStep import org.clulab.processors.shallownlp.ShallowNLPProcessor import org.clulab.struct.GraphMap @@ -22,7 +24,7 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS new CluProcessor() { // Since this skips CluProcessor.srl() and goes straight for srlSentence(), there isn't // a chance to make sure CluProcessor.mtlSrla is initialized, so it is done here. - assert(this.mtlSrla != null) + assert(this.srlaBackend != null) } } @@ -34,21 +36,21 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS } override def srl(doc: Document): Unit = { - val embeddings = ConstEmbeddingsGlove.mkConstLookupParams(doc) + val embeddingsAttachment = new EmbeddingsAttachment(MetalBackend.mkEmbeddings(doc)) val docDate = doc.getDCT for(sent <- doc.sentences) { val words = sent.words val lemmas = sent.lemmas // The SRL model relies on NEs produced by CluProcessor, so run the NER first - val (tags, _, preds) = cluProcessor.tagSentence(words, embeddings) + val (tags, _, preds) = cluProcessor.tagSentence(words, embeddingsAttachment) val predIndexes = cluProcessor.getPredicateIndexes(preds) val tagsAsArray = tags.toArray val (entities, _) = cluProcessor.nerSentence( - words, lemmas, tagsAsArray, sent.startOffsets, sent.endOffsets, docDate, embeddings + words, lemmas, tagsAsArray, sent.startOffsets, sent.endOffsets, docDate, embeddingsAttachment ) val semanticRoles = cluProcessor.srlSentence( - words, tagsAsArray, entities, predIndexes, embeddings + words, tagsAsArray, entities, predIndexes, embeddingsAttachment ) sent.graphs += GraphMap.SEMANTIC_ROLES -> semanticRoles @@ -60,5 +62,6 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS sent.graphs += GraphMap.ENHANCED_SEMANTIC_ROLES -> enhancedRoles } } + embeddingsAttachment.close() } } diff --git a/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala b/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala index c5c67066b..ac5933103 100644 --- a/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala +++ b/main/src/main/scala/org/clulab/dynet/ConstEmbeddingsGlove.scala @@ -7,6 +7,7 @@ import org.slf4j.{Logger, LoggerFactory} import org.clulab.utils.ConfigWithDefaults import org.clulab.utils.StringUtils +import java.io.Closeable import scala.collection.mutable class ConstEmbeddingsGlove @@ -14,7 +15,12 @@ class ConstEmbeddingsGlove /** Stores lookup parameters + the map from strings to ids */ case class ConstEmbeddingParameters(collection: ParameterCollection, lookupParameters: LookupParameter, - w2i: Map[String, Int]) + w2i: Map[String, Int]) extends Closeable { + def close(): Unit = { + lookupParameters.close() + collection.close() + } +} /** * Implements the ConstEmbeddings as a thin wrapper around WordEmbeddingMap diff --git a/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala b/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala index 28dda1e34..dc7cde9b4 100644 --- a/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala +++ b/main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala @@ -1,15 +1,14 @@ package org.clulab.dynet import java.io.PrintWriter - import edu.cmu.dynet.Expression.concatenate import edu.cmu.dynet.{Dim, Expression, ExpressionVector, LookupParameter, LstmBuilder, ParameterCollection, RnnBuilder} import org.clulab.struct.Counter import org.slf4j.{Logger, LoggerFactory} import org.clulab.dynet.Utils._ import org.clulab.utils.{Configured, Serializer} - import EmbeddingLayer._ +import org.clulab.processors.clu.AnnotatedSentence import scala.util.Random diff --git a/main/src/main/scala/org/clulab/dynet/InitialLayer.scala b/main/src/main/scala/org/clulab/dynet/InitialLayer.scala index 9261a6ee9..21c076463 100644 --- a/main/src/main/scala/org/clulab/dynet/InitialLayer.scala +++ b/main/src/main/scala/org/clulab/dynet/InitialLayer.scala @@ -1,6 +1,7 @@ package org.clulab.dynet import edu.cmu.dynet.{ExpressionVector, LookupParameter} +import org.clulab.processors.clu.AnnotatedSentence /** * First layer that occurs in a sequence modeling architecture: goes from words to Expressions diff --git a/main/src/main/scala/org/clulab/dynet/Layers.scala b/main/src/main/scala/org/clulab/dynet/Layers.scala index d5a34b10f..e6f9ee1f8 100644 --- a/main/src/main/scala/org/clulab/dynet/Layers.scala +++ b/main/src/main/scala/org/clulab/dynet/Layers.scala @@ -6,6 +6,7 @@ import org.clulab.struct.Counter import org.clulab.utils.Configured import org.clulab.dynet.Utils._ import org.clulab.fatdynet.utils.Synchronizer +import org.clulab.processors.clu.AnnotatedSentence import scala.collection.mutable.ArrayBuffer diff --git a/main/src/main/scala/org/clulab/dynet/Metal.scala b/main/src/main/scala/org/clulab/dynet/Metal.scala index 5dc99cb2c..e327e99e4 100644 --- a/main/src/main/scala/org/clulab/dynet/Metal.scala +++ b/main/src/main/scala/org/clulab/dynet/Metal.scala @@ -1,7 +1,6 @@ package org.clulab.dynet import java.io.{FileWriter, PrintWriter} - import com.typesafe.config.ConfigFactory import edu.cmu.dynet.{AdamTrainer, ComputationGraph, Expression, ExpressionVector, ParameterCollection, RMSPropTrainer, SimpleSGDTrainer} import org.clulab.dynet.Utils._ @@ -14,8 +13,8 @@ import org.clulab.fatdynet.utils.Closer.AutoCloser import scala.collection.mutable.ArrayBuffer import scala.util.Random - import Metal._ +import org.clulab.processors.clu.AnnotatedSentence /** * Multi-task learning (MeTaL) for sequence modeling diff --git a/main/src/main/scala/org/clulab/dynet/MetalShell.scala b/main/src/main/scala/org/clulab/dynet/MetalShell.scala index 256510dae..b1acbe250 100644 --- a/main/src/main/scala/org/clulab/dynet/MetalShell.scala +++ b/main/src/main/scala/org/clulab/dynet/MetalShell.scala @@ -1,5 +1,6 @@ package org.clulab.dynet +import org.clulab.processors.clu.AnnotatedSentence import org.clulab.utils.Shell class MetalShell(val mtl: Metal) extends Shell { diff --git a/main/src/main/scala/org/clulab/dynet/RowReaders.scala b/main/src/main/scala/org/clulab/dynet/RowReaders.scala index 098fd61d6..7eaad30a0 100644 --- a/main/src/main/scala/org/clulab/dynet/RowReaders.scala +++ b/main/src/main/scala/org/clulab/dynet/RowReaders.scala @@ -8,16 +8,8 @@ package org.clulab.dynet import org.clulab.sequences.Row import scala.collection.mutable.ArrayBuffer - import MetalRowReader._ - -case class AnnotatedSentence(words: IndexedSeq[String], - posTags: Option[IndexedSeq[String]] = None, - neTags: Option[IndexedSeq[String]] = None, - headPositions: Option[IndexedSeq[Int]] = None) { - def indices: Range = words.indices - def size: Int = words.size -} +import org.clulab.processors.clu.AnnotatedSentence trait RowReader { /** Converts the tabular format into one or more (AnnotatedSentence, sequence of gold labels) pairs */ diff --git a/main/src/main/scala/org/clulab/processors/clu/AnnotatedSentence.scala b/main/src/main/scala/org/clulab/processors/clu/AnnotatedSentence.scala new file mode 100644 index 000000000..50c44a413 --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/AnnotatedSentence.scala @@ -0,0 +1,11 @@ +package org.clulab.processors.clu + +case class AnnotatedSentence( + words: IndexedSeq[String], + posTags: Option[IndexedSeq[String]] = None, + neTags: Option[IndexedSeq[String]] = None, + headPositions: Option[IndexedSeq[Int]] = None +) { + def indices: Range = words.indices + def size: Int = words.size +} diff --git a/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala b/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala index 13f13a5fb..0febc3193 100644 --- a/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala +++ b/main/src/main/scala/org/clulab/processors/clu/CluProcessor.scala @@ -9,8 +9,17 @@ import org.slf4j.{Logger, LoggerFactory} import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import CluProcessor._ -import org.clulab.dynet.{AnnotatedSentence, ConstEmbeddingParameters, ConstEmbeddingsGlove, Metal} +import org.clulab.processors.clu.backend.DepsBackend +import org.clulab.processors.clu.backend.MetalDepsBackend +import org.clulab.processors.clu.backend.MetalNerBackend +import org.clulab.processors.clu.backend.MetalPosBackend +import org.clulab.processors.clu.backend.MetalSrlaBackend +import org.clulab.processors.clu.backend.NerBackend +import org.clulab.processors.clu.backend.PosBackend +import org.clulab.processors.clu.backend.SrlaBackend import org.clulab.numeric.{NumericEntityRecognizer, setLabelsAndNorms} +import org.clulab.processors.clu.backend.EmbeddingsAttachment +import org.clulab.processors.clu.backend.MetalBackend import org.clulab.sequences.LexiconNER import org.clulab.struct.{DirectedGraph, Edge, GraphMap} import org.clulab.utils.BeforeAndAfter @@ -52,27 +61,28 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), } // one of the multi-task learning (MTL) models, which covers: POS, chunking, and SRL (predicates) - lazy val mtlPosChunkSrlp: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val posBackend: PosBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-pos-chunk-srlp", Some("mtl-en-pos-chunk-srlp"))) + // Use the config to decide which one to use. + case _ => new MetalPosBackend(getArgString(s"$prefix.mtl-pos-chunk-srlp", Some("mtl-en-pos-chunk-srlp"))) } // one of the multi-task learning (MTL) models, which covers: NER - lazy val mtlNer: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val nerBackend: NerBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-ner", Some("mtl-en-ner"))) + case _ => new MetalNerBackend(getArgString(s"$prefix.mtl-ner", Some("mtl-en-ner"))) } // recognizes numeric entities using Odin rules lazy val numericEntityRecognizer = new NumericEntityRecognizer // one of the multi-task learning (MTL) models, which covers: SRL (arguments) - lazy val mtlSrla: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val srlaBackend: SrlaBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-srla", Some("mtl-en-srla"))) + case _ => new MetalSrlaBackend(getArgString(s"$prefix.mtl-srla", Some("mtl-en-srla"))) } /* @@ -88,10 +98,10 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), case _ => Metal(getArgString(s"$prefix.mtl-depsl", Some("mtl-en-depsl"))) } */ - lazy val mtlDeps: Metal = getArgString(s"$prefix.language", Some("EN")) match { + lazy val depsBackend: DepsBackend = getArgString(s"$prefix.language", Some("EN")) match { case "PT" => throw new RuntimeException("PT model not trained yet") // Add PT case "ES" => throw new RuntimeException("ES model not trained yet") // Add ES - case _ => Metal(getArgString(s"$prefix.mtl-deps", Some("mtl-en-deps"))) + case _ => new MetalDepsBackend(getArgString(s"$prefix.mtl-deps", Some("mtl-en-deps"))) } // Although this uses no class members, the method is sometimes called from tests @@ -149,14 +159,10 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), class PredicateAttachment(val predicates: IndexedSeq[IndexedSeq[Int]]) extends IntermediateDocumentAttachment /** Produces POS tags, chunks, and semantic role predicates for one sentence */ - def tagSentence(words: IndexedSeq[String], embeddings: ConstEmbeddingParameters): + def tagSentence(words: IndexedSeq[String], embeddingsAttachment: EmbeddingsAttachment): (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { - val allLabels = mtlPosChunkSrlp.predictJointly(AnnotatedSentence(words), embeddings) - val tags = allLabels(0) - val chunks = allLabels(1) - val preds = allLabels(2) - (tags, chunks, preds) + posBackend.predict(AnnotatedSentence(words), embeddingsAttachment) } /** Produces NE labels for one sentence */ @@ -166,10 +172,10 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), startCharOffsets: Array[Int], endCharOffsets: Array[Int], docDateOpt: Option[String], - embeddings: ConstEmbeddingParameters): (IndexedSeq[String], Option[IndexedSeq[String]]) = { + embeddingsAttachment: EmbeddingsAttachment): (IndexedSeq[String], Option[IndexedSeq[String]]) = { // NER labels from the statistical model - val allLabels = mtlNer.predictJointly(AnnotatedSentence(words), embeddings) + val labels = nerBackend.predict(AnnotatedSentence(words), embeddingsAttachment) // NER labels from the custom NER val optionalNERLabels: Option[Array[String]] = { @@ -196,9 +202,9 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), } if(optionalNERLabels.isEmpty) { - (allLabels(0), None) + (labels, None) } else { - (mergeNerLabels(allLabels(0), optionalNERLabels.get), None) + (mergeNerLabels(labels, optionalNERLabels.get), None) } } @@ -239,7 +245,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), def parseSentence(words: IndexedSeq[String], posTags: IndexedSeq[String], nerLabels: IndexedSeq[String], - embeddings: ConstEmbeddingParameters): DirectedGraph[String] = { + embeddingsAttachment: EmbeddingsAttachment): DirectedGraph[String] = { //println(s"Words: ${words.mkString(", ")}") //println(s"Tags: ${posTags.mkString(", ")}") @@ -248,7 +254,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), val annotatedSentence = AnnotatedSentence(words, Some(posTags), Some(nerLabels)) - val headsAndLabels = mtlDeps.parse(annotatedSentence, embeddings) + val headsAndLabels = depsBackend.predict(annotatedSentence, embeddingsAttachment) val edges = new ListBuffer[Edge[String]]() val roots = new mutable.HashSet[Int]() @@ -323,12 +329,12 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), def srlSentence(sent: Sentence, predicateIndexes: IndexedSeq[Int], - embeddings: ConstEmbeddingParameters): DirectedGraph[String] = { + embeddingsAttachment: EmbeddingsAttachment): DirectedGraph[String] = { // the SRL models were trained using only named (CoNLL) entities, not numeric ones, so let's remove them // TODO: retrain the SRL using numeric entities too, when NumericEntityRecognizer is stable val onlyNamedLabels = removeNumericLabels(sent.entities.get) - srlSentence(sent.words, sent.tags.get, onlyNamedLabels, predicateIndexes, embeddings) + srlSentence(sent.words, sent.tags.get, onlyNamedLabels, predicateIndexes, embeddingsAttachment) } def removeNumericLabels(allLabels: Array[String]): Array[String] = { @@ -349,12 +355,12 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), posTags: IndexedSeq[String], nerLabels: IndexedSeq[String], predicateIndexes: IndexedSeq[Int], - embeddings: ConstEmbeddingParameters): DirectedGraph[String] = { + embeddingsAttachment: EmbeddingsAttachment): DirectedGraph[String] = { val edges = predicateIndexes.flatMap { pred => // SRL needs POS tags and NEs, as well as the position of the predicate val headPositions = Array.fill(words.length)(pred) val annotatedSentence = AnnotatedSentence(words, Some(posTags), Some(nerLabels), Some(headPositions)) - val argLabels = mtlSrla.predict(0, annotatedSentence, embeddings) + val argLabels = srlaBackend.predict(0, annotatedSentence, embeddingsAttachment) argLabels.zipWithIndex .filter { case (argLabel, _) => argLabel != "O" } @@ -363,14 +369,14 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), new DirectedGraph[String](edges.toList, Some(words.length)) } - private def getEmbeddings(doc: Document): ConstEmbeddingParameters = - doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).get.asInstanceOf[EmbeddingsAttachment].embeddings + private def getEmbeddingsAttachment(doc: Document): EmbeddingsAttachment = + doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).get.asInstanceOf[EmbeddingsAttachment] /** Part of speech tagging + chunking + SRL (predicates), jointly */ override def tagPartsOfSpeech(doc:Document) { basicSanityCheck(doc) - val embeddings = getEmbeddings(doc) + val embeddings = getEmbeddingsAttachment(doc) val predsForAllSents = new ArrayBuffer[IndexedSeq[Int]]() for(sent <- doc.sentences) { @@ -434,7 +440,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), // // names entities from CoNLL // - val embeddings = getEmbeddings(doc) + val embeddingsAttachment = getEmbeddingsAttachment(doc) val docDate = doc.getDCT for(sent <- doc.sentences) { val (labels, norms) = nerSentence( @@ -444,7 +450,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), sent.startOffsets, sent.endOffsets, docDate, - embeddings) + embeddingsAttachment) sent.entities = Some(labels.toArray) if(norms.nonEmpty) { @@ -473,7 +479,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), if(sentence.universalBasicDependencies.isEmpty) return origPreds if(sentence.tags.isEmpty) return origPreds - + val preds = origPreds.toSet val newPreds = new mutable.HashSet[Int]() newPreds ++= preds @@ -499,7 +505,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), val predicatesAttachment = doc.getAttachment(PREDICATE_ATTACHMENT_NAME) assert(predicatesAttachment.nonEmpty) assert(doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).isDefined) - val embeddings = getEmbeddings(doc) + val embeddingsAttachment = getEmbeddingsAttachment(doc) if(doc.sentences.length > 0) { assert(doc.sentences(0).tags.nonEmpty) @@ -515,14 +521,14 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), // until later, possibly in a parallel processing situation which will result sooner or later and // under the right conditions in DyNet crashing. Therefore, make preemptive reference to mtlSrla // here so that it is sure to be initialized, regardless of the priming sentence. - assert(mtlSrla != null) + assert(srlaBackend != null) // generate SRL frames for each predicate in each sentence for(si <- predicates.indices) { val sentence = doc.sentences(si) val predicateIndexes = predicateCorrections(predicates(si), sentence) - val semanticRoles = srlSentence(sentence, predicateIndexes, embeddings) + val semanticRoles = srlSentence(sentence, predicateIndexes, embeddingsAttachment) sentence.graphs += GraphMap.SEMANTIC_ROLES -> semanticRoles @@ -554,7 +560,7 @@ class CluProcessor (val config: Config = ConfigFactory.load("cluprocessor"), assert(doc.sentences(0).entities.nonEmpty) } assert(doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).isDefined) - val embeddings = getEmbeddings(doc) + val embeddings = getEmbeddingsAttachment(doc) for(sent <- doc.sentences) { val depGraph = parseSentence(sent.words, sent.tags.get, sent.entities.get, embeddings) @@ -742,9 +748,6 @@ object CluProcessor { } } -case class EmbeddingsAttachment(embeddings: ConstEmbeddingParameters) - extends IntermediateDocumentAttachment - class GivenConstEmbeddingsAttachment(doc: Document) extends BeforeAndAfter { def before(): Unit = GivenConstEmbeddingsAttachment.mkConstEmbeddings(doc) @@ -752,12 +755,7 @@ class GivenConstEmbeddingsAttachment(doc: Document) extends BeforeAndAfter { def after(): Unit = { val attachment = doc.getAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME).get.asInstanceOf[EmbeddingsAttachment] doc.removeAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME) - - // This is a memory management optimization. - val embeddings = attachment.embeddings - // FatDynet needs to be updated before these can be used. - embeddings.lookupParameters.close() - embeddings.collection.close() + attachment.close() } } @@ -767,8 +765,8 @@ object GivenConstEmbeddingsAttachment { // This is static so that it can be called without an object. def mkConstEmbeddings(doc: Document): Unit = { // Fetch the const embeddings from GloVe. All our models need them. - val embeddings = ConstEmbeddingsGlove.mkConstLookupParams(doc) - val attachment = EmbeddingsAttachment(embeddings) + val embeddings = MetalBackend.mkEmbeddings(doc) + val attachment = new EmbeddingsAttachment(embeddings) // Now set them as an attachment, so they are available to all downstream methods wo/ changing the API. doc.addAttachment(CONST_EMBEDDINGS_ATTACHMENT_NAME, attachment) diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala new file mode 100644 index 000000000..074f015c7 --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/CluBackend.scala @@ -0,0 +1,43 @@ +package org.clulab.processors.clu.backend + +import org.clulab.processors.Document +import org.clulab.processors.IntermediateDocumentAttachment +import org.clulab.processors.clu.AnnotatedSentence + +import java.io.Closeable + +case class CloseableNone() extends Closeable { + def close(): Unit = () +} + +trait CluBackend { + def mkEmbeddings(doc: Document): Closeable = CloseableNone() +} + +class EmbeddingsAttachment(protected val value: Closeable) extends IntermediateDocumentAttachment with Closeable { + // TODO: This would need to change if multiple values were to be stored. + + def get[T]: T = value.asInstanceOf[T] // Use caution! + + def close(): Unit = value.close() +} + +trait PosBackend { + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) // tags, chunks, and preds +} + +trait NerBackend { + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] // labels +} + +trait SrlaBackend { + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] // labels +} + +trait DepsBackend { + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala new file mode 100644 index 000000000..4c63dd6ed --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/MetalBackend.scala @@ -0,0 +1,51 @@ +package org.clulab.processors.clu.backend + +import org.clulab.dynet.ConstEmbeddingsGlove +import org.clulab.dynet.Metal +import org.clulab.processors.Document +import org.clulab.processors.clu.AnnotatedSentence + +import java.io.Closeable + +object MetalBackend extends CluBackend { + + override def mkEmbeddings(doc: Document): Closeable = { + // Fetch the const embeddings from GloVe. All our models need them. + ConstEmbeddingsGlove.mkConstLookupParams(doc) + } +} + +class MetalPosBackend(modelFilenamePrefix: String) extends PosBackend { + protected val mtlPosChunkSrlp: Metal = Metal(modelFilenamePrefix) + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { + val IndexedSeq(tags, chunks, preds, _*) = mtlPosChunkSrlp.predictJointly(annotatedSentence, embeddingsAttachment.get) + + (tags, chunks, preds) + } +} + +class MetalNerBackend(modelFilenamePrefix: String) extends NerBackend { + protected val mtlNer: Metal = Metal(modelFilenamePrefix) + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = + mtlNer.predictJointly(annotatedSentence, embeddingsAttachment.get).head // labels +} + +class MetalSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + protected val mtlSrla: Metal = Metal(modelFilenamePrefix) + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = + mtlSrla.predict(taskId, annotatedSentence, embeddingsAttachment.get) // labels +} + +class MetalDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + protected val mtlDeps: Metal = Metal(modelFilenamePrefix) + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = + mtlDeps.parse(annotatedSentence, embeddingsAttachment.get) // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/OnnxBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/OnnxBackend.scala new file mode 100644 index 000000000..7875609af --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/OnnxBackend.scala @@ -0,0 +1,29 @@ +package org.clulab.processors.clu.backend + +import org.clulab.processors.clu.AnnotatedSentence + +object OnnxBackend extends CluBackend + +class OnnxPosBackend(modelFilenamePrefix: String) extends PosBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = ??? // tags, chunks, and preds +} + +class OnnxNerBackend(modelFilenamePrefix: String) extends NerBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class OnnxSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class OnnxDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = ??? // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala new file mode 100644 index 000000000..0b2115dd3 --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/ScalaBackend.scala @@ -0,0 +1,29 @@ +package org.clulab.processors.clu.backend + +import org.clulab.processors.clu.AnnotatedSentence + +object ScalaBackend extends CluBackend + +class ScalaPosBackend(modelFilenamePrefix: String) extends PosBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = ??? // tags, chunks, and preds +} + +class ScalaNerBackend(modelFilenamePrefix: String) extends NerBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class ScalaSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class ScalaDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = ??? // heads and labels +} diff --git a/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala b/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala new file mode 100644 index 000000000..f950f11a2 --- /dev/null +++ b/main/src/main/scala/org/clulab/processors/clu/backend/TorchBackend.scala @@ -0,0 +1,29 @@ +package org.clulab.processors.clu.backend + +import org.clulab.processors.clu.AnnotatedSentence + +object TorchBackend extends CluBackend + +class TorchPosBackend(modelFilenamePrefix: String) extends PosBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = ??? // tags, chunks, and preds +} + +class TorchNerBackend(modelFilenamePrefix: String) extends NerBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class TorchSrlaBackend(modelFilenamePrefix: String) extends SrlaBackend { + + def predict(taskId: Int, annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[String] = ??? // labels +} + +class TorchDepsBackend(modelFilenamePrefix: String) extends DepsBackend { + + def predict(annotatedSentence: AnnotatedSentence, embeddingsAttachment: EmbeddingsAttachment): + IndexedSeq[(Int, String)] = ??? // heads and labels +}