Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make room for other backend (NN) implementations #549

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package org.clulab.processors.fastnlp
import org.clulab.dynet.{ConstEmbeddingsGlove, Utils}
import org.clulab.processors.Document
import org.clulab.processors.clu.CluProcessor
import org.clulab.processors.clu.backend.EmbeddingsAttachment
import org.clulab.processors.clu.backend.MetalBackend
import org.clulab.processors.clu.tokenizer.TokenizerStep
import org.clulab.processors.shallownlp.ShallowNLPProcessor
import org.clulab.struct.GraphMap
Expand All @@ -22,7 +24,7 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS
new CluProcessor() {
// Since this skips CluProcessor.srl() and goes straight for srlSentence(), there isn't
// a chance to make sure CluProcessor.srlaBackend is initialized, so it is done here.
assert(this.mtlSrla != null)
assert(this.srlaBackend != null)
}
}

Expand All @@ -34,21 +36,21 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS
}

override def srl(doc: Document): Unit = {
val embeddings = ConstEmbeddingsGlove.mkConstLookupParams(doc)
val embeddingsAttachment = new EmbeddingsAttachment(MetalBackend.mkEmbeddings(doc))
val docDate = doc.getDCT
for(sent <- doc.sentences) {
val words = sent.words
val lemmas = sent.lemmas

// The SRL model relies on NEs produced by CluProcessor, so run the NER first
val (tags, _, preds) = cluProcessor.tagSentence(words, embeddings)
val (tags, _, preds) = cluProcessor.tagSentence(words, embeddingsAttachment)
val predIndexes = cluProcessor.getPredicateIndexes(preds)
val tagsAsArray = tags.toArray
val (entities, _) = cluProcessor.nerSentence(
words, lemmas, tagsAsArray, sent.startOffsets, sent.endOffsets, docDate, embeddings
words, lemmas, tagsAsArray, sent.startOffsets, sent.endOffsets, docDate, embeddingsAttachment
)
val semanticRoles = cluProcessor.srlSentence(
words, tagsAsArray, entities, predIndexes, embeddings
words, tagsAsArray, entities, predIndexes, embeddingsAttachment
)

sent.graphs += GraphMap.SEMANTIC_ROLES -> semanticRoles
Expand All @@ -60,5 +62,6 @@ class FastNLPProcessorWithSemanticRoles(tokenizerPostProcessor:Option[TokenizerS
sent.graphs += GraphMap.ENHANCED_SEMANTIC_ROLES -> enhancedRoles
}
}
embeddingsAttachment.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ import org.slf4j.{Logger, LoggerFactory}
import org.clulab.utils.ConfigWithDefaults
import org.clulab.utils.StringUtils

import java.io.Closeable
import scala.collection.mutable

class ConstEmbeddingsGlove

/**
 * Stores lookup parameters + the map from strings to ids.
 *
 * Implements `Closeable`: call `close()` once the embeddings are no longer
 * needed so that both wrapped objects are closed.
 *
 * @param collection       parameter collection that is closed together with the lookup parameters
 * @param lookupParameters lookup parameters for the embeddings
 * @param w2i              map from strings to ids
 */
case class ConstEmbeddingParameters(collection: ParameterCollection,
                                    lookupParameters: LookupParameter,
                                    w2i: Map[String, Int]) extends Closeable {
  /**
   * Closes the lookup parameters first and then the collection.
   *
   * The collection is closed in a `finally` so that it is still closed when
   * closing the lookup parameters throws; the exception then propagates.
   */
  def close(): Unit = {
    try lookupParameters.close()
    finally collection.close()
  }
}

/**
* Implements the ConstEmbeddings as a thin wrapper around WordEmbeddingMap
Expand Down
3 changes: 1 addition & 2 deletions main/src/main/scala/org/clulab/dynet/EmbeddingLayer.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package org.clulab.dynet

import java.io.PrintWriter

import edu.cmu.dynet.Expression.concatenate
import edu.cmu.dynet.{Dim, Expression, ExpressionVector, LookupParameter, LstmBuilder, ParameterCollection, RnnBuilder}
import org.clulab.struct.Counter
import org.slf4j.{Logger, LoggerFactory}
import org.clulab.dynet.Utils._
import org.clulab.utils.{Configured, Serializer}

import EmbeddingLayer._
import org.clulab.processors.clu.AnnotatedSentence

import scala.util.Random

Expand Down
1 change: 1 addition & 0 deletions main/src/main/scala/org/clulab/dynet/InitialLayer.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.clulab.dynet

import edu.cmu.dynet.{ExpressionVector, LookupParameter}
import org.clulab.processors.clu.AnnotatedSentence

/**
* First layer that occurs in a sequence modeling architecture: goes from words to Expressions
Expand Down
1 change: 1 addition & 0 deletions main/src/main/scala/org/clulab/dynet/Layers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.clulab.struct.Counter
import org.clulab.utils.Configured
import org.clulab.dynet.Utils._
import org.clulab.fatdynet.utils.Synchronizer
import org.clulab.processors.clu.AnnotatedSentence

import scala.collection.mutable.ArrayBuffer

Expand Down
3 changes: 1 addition & 2 deletions main/src/main/scala/org/clulab/dynet/Metal.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.clulab.dynet

import java.io.{FileWriter, PrintWriter}

import com.typesafe.config.ConfigFactory
import edu.cmu.dynet.{AdamTrainer, ComputationGraph, Expression, ExpressionVector, ParameterCollection, RMSPropTrainer, SimpleSGDTrainer}
import org.clulab.dynet.Utils._
Expand All @@ -14,8 +13,8 @@ import org.clulab.fatdynet.utils.Closer.AutoCloser

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import Metal._
import org.clulab.processors.clu.AnnotatedSentence

/**
* Multi-task learning (MeTaL) for sequence modeling
Expand Down
1 change: 1 addition & 0 deletions main/src/main/scala/org/clulab/dynet/MetalShell.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.clulab.dynet

import org.clulab.processors.clu.AnnotatedSentence
import org.clulab.utils.Shell

class MetalShell(val mtl: Metal) extends Shell {
Expand Down
10 changes: 1 addition & 9 deletions main/src/main/scala/org/clulab/dynet/RowReaders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,8 @@ package org.clulab.dynet
import org.clulab.sequences.Row

import scala.collection.mutable.ArrayBuffer

import MetalRowReader._

/**
 * A sentence represented by its words plus optional per-word annotations.
 *
 * @param words         the words of the sentence
 * @param posTags       optional tags, presumably part-of-speech tags (one per word) -- TODO confirm
 * @param neTags        optional tags, presumably named-entity tags (one per word) -- TODO confirm
 * @param headPositions optional integer positions, presumably the head of each word -- TODO confirm
 */
case class AnnotatedSentence(words: IndexedSeq[String],
posTags: Option[IndexedSeq[String]] = None,
neTags: Option[IndexedSeq[String]] = None,
headPositions: Option[IndexedSeq[Int]] = None) {
// The range of valid positions into `words`.
def indices: Range = words.indices
// The number of words in the sentence.
def size: Int = words.size
}
import org.clulab.processors.clu.AnnotatedSentence

trait RowReader {
/** Converts the tabular format into one or more (AnnotatedSentence, sequence of gold labels) pairs */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.clulab.processors.clu

/**
 * A sentence represented by its words plus optional per-word annotations.
 *
 * All annotations default to `None`, so a sentence can be built from words alone.
 *
 * @param words         the words of the sentence
 * @param posTags       optional tags, presumably part-of-speech tags (one per word) -- TODO confirm
 * @param neTags        optional tags, presumably named-entity tags (one per word) -- TODO confirm
 * @param headPositions optional integer positions, presumably the head of each word -- TODO confirm
 */
case class AnnotatedSentence(
  words: IndexedSeq[String],
  posTags: Option[IndexedSeq[String]] = None,
  neTags: Option[IndexedSeq[String]] = None,
  headPositions: Option[IndexedSeq[Int]] = None
) {
  /** The number of words in the sentence. */
  def size: Int = words.length

  /** The range of valid positions into `words`: 0 up to, but excluding, `size`. */
  def indices: Range = 0 until words.length
}
Loading