diff --git a/.gitignore b/.gitignore index b998571c2..a3e697daf 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ edison/wordnet/ **/.project **/.classpath +*~ diff --git a/README.md b/README.md index 1c2425fe6..4af2651c4 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,8 @@ Each library contains detailed readme and instructions on how to use it. In addi | [curator](curator/README.md) | Supports use of [CogComp NLP Curator](http://cogcomp.cs.illinois.edu/page/software_view/Curator), a tool to run NLP applications as services. | | [edison](edison/README.md) | A library for feature extraction from `core-utilities` data structures. | | [lemmatizer](lemmatizer/README.md) | An application that uses [WordNet](https://wordnet.princeton.edu/) and simple rules to find the root forms of words in plain text. | -| [tokenizer](tokenizer/README.md) | An application that identifies sentence and word boundaries in plain text. | +| [tokenizer](tokenizer/README.md) | An application that identifies sentence and word boundaries in plain text. | +| [transliteration](transliteration/README.md) | An application that transliterates names between different scripts. | | [pos](pos/README.md) | An application that identifies the part of speech (e.g. verb + tense, noun + number) of each word in plain text. | | [ner](ner/README.md) | An application that identifies named entities in plain text according to two different sets of categories. | | [md](md/README.md) | An application that identifies entity mentions in plain text. | diff --git a/core-utilities/src/main/java/edu/illinois/cs/cogcomp/annotation/Annotator.java b/core-utilities/src/main/java/edu/illinois/cs/cogcomp/annotation/Annotator.java index ca7a4352c..d9ab17040 100644 --- a/core-utilities/src/main/java/edu/illinois/cs/cogcomp/annotation/Annotator.java +++ b/core-utilities/src/main/java/edu/illinois/cs/cogcomp/annotation/Annotator.java @@ -14,6 +14,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.util.Properties; import java.util.Set; diff --git a/core-utilities/src/main/java/edu/illinois/cs/cogcomp/core/datastructures/ViewNames.java b/core-utilities/src/main/java/edu/illinois/cs/cogcomp/core/datastructures/ViewNames.java index 68540804a..3b8bca274 100644 --- a/core-utilities/src/main/java/edu/illinois/cs/cogcomp/core/datastructures/ViewNames.java +++ b/core-utilities/src/main/java/edu/illinois/cs/cogcomp/core/datastructures/ViewNames.java @@ -128,6 +128,7 @@ public class ViewNames { public static final String POST_ERE = "POST_ERE"; public static final String EVENT_ERE = "EVENT_ERE"; + public static final String TRANSLITERATION = "TRANSLITERATION"; public static ViewTypes getViewType(String viewName) { switch (viewName) { diff --git a/pipeline/pom.xml b/pipeline/pom.xml index d1d1cef48..a78e8990a 100644 --- a/pipeline/pom.xml +++ b/pipeline/pom.xml @@ -136,6 +136,11 @@ illinois-time 3.1.33 + + edu.illinois.cs.cogcomp + illinois-transliteration + 3.1.33 + diff --git a/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/common/PipelineConfigurator.java b/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/common/PipelineConfigurator.java index b0de55fd5..611a12a9e 100644 --- a/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/common/PipelineConfigurator.java +++ b/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/common/PipelineConfigurator.java @@ -40,6 +40,7 @@ public class PipelineConfigurator extends AnnotatorServiceConfigurator { public static final Property USE_QUANTIFIER 
= new Property("useQuantifier", FALSE); public static final Property USE_VERB_SENSE = new Property("useVerbSense", FALSE); public static final Property USE_JSON = new Property("useJson", FALSE); + public static final Property USE_TRANSLITERATION = new Property("useTransliteration", FALSE); public static final Property USE_MENTION = new Property("useMention", FALSE); public static final Property USE_LAZY_INITIALIZATION = new Property( AnnotatorConfigurator.IS_LAZILY_INITIALIZED.key, TRUE); diff --git a/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/main/PipelineFactory.java b/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/main/PipelineFactory.java index 3a9a20e14..34a3a58c5 100644 --- a/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/main/PipelineFactory.java +++ b/pipeline/src/main/java/edu/illinois/cs/cogcomp/pipeline/main/PipelineFactory.java @@ -121,6 +121,10 @@ public static BasicAnnotatorService buildPipeline(Boolean disableCache, String.. nonDefaultValues.put(PipelineConfigurator.USE_VERB_SENSE.key, Configurator.TRUE); break; + case ViewNames.TRANSLITERATION: + nonDefaultValues.put(PipelineConfigurator.USE_TRANSLITERATION.key, + Configurator.TRUE); + break; case ViewNames.TIMEX3: nonDefaultValues.put(PipelineConfigurator.USE_TIMEX3.key, Configurator.TRUE); @@ -346,6 +350,11 @@ private static Map buildAnnotators(ResourceManager nonDefault viewGenerators.put(ViewNames.QUANTITIES, quantifierAnnotator); } + if (rm.getBoolean(PipelineConfigurator.USE_TRANSLITERATION)) { + TransliterationAnnotator transliterationAnnotator = new TransliterationAnnotator(); + viewGenerators.put(ViewNames.TRANSLITERATION, transliterationAnnotator); + } + if (rm.getBoolean(PipelineConfigurator.USE_SRL_PREP)) { PrepSRLAnnotator prepSRLAnnotator = new PrepSRLAnnotator(); viewGenerators.put(ViewNames.SRL_PREP, prepSRLAnnotator); diff --git a/pom.xml b/pom.xml index 5ee72755b..5fff1de9f 100644 --- a/pom.xml +++ b/pom.xml @@ -11,6 +11,7 @@ core-utilities tokenizer + transliteration lemmatizer edison curator diff --git a/transliteration/README.md b/transliteration/README.md new file mode 100644 index 000000000..9b58929ab --- /dev/null +++ b/transliteration/README.md @@ -0,0 +1,60 @@ +# Transliteration + +This is a Java port of Jeff Pasternack's C# code from [Learning Better Transliterations](http://cogcomp.org/page/publication_view/205) + +See examples in [TestTransliteration](src/test/java/edu/illinois/cs/cogcomp/transliteration/TestTransliteration.java) +or [Runner](src/main/java/edu/illinois/cs/cogcomp/transliteration/Runner.java). + + +## Training data + +To train a model, you need pairs of names. A common source is Wikipedia interlanguage links. For example, +see [this data](http://www.clsp.jhu.edu/~anni/data/wikipedia_names) +from [Transliterating From All Languages](http://cis.upenn.edu/~ccb/publications/transliterating-from-all-languages.pdf) +by Anne Irvine et al. + +The standard data format expected is: +```bash +foreignenglish +``` + +That said, the [Utils class](src/main/java/edu/illinois/cs/cogcomp/utils/Utils.java) has readers for many +different datasets (including Anne Irvine's data). + +## Training a model +The standard class is the [SPModel](src/main/java/edu/illinois/cs/cogcomp/transliteration/SPModel.java). Use it +as follows: + +```java +List training = Utils.readWikiData(trainfile); +SPModel model = new SPModel(training); +model.Train(10); +model.WriteProbs(modelfile); + +``` + +This will train a model, and write it to the path specified by `modelfile`. 
+ +`SPModel` has another useful function called `Probability(source, target)`, which will return the transliteration probability +of a given pair. + +## Annotating +A trained model can be used immediately after training, or you can initialize `SPModel` using a +previously trained and saved `modelfile`. + +```java +SPModel model = new SPModel(modelfile); +model.setMaxCandidates(10); +TopList predictions = model.Generate(testexample); +``` + +We limited the max number of candidates to 10, so `predictions` will have at most 10 elements. These +are sorted by score, highest to lowest, where the first element is the best. + +## Interactive + +Once you have trained a model, it is often helpful to try interacting with it. Use [interactive.sh](scripts/interactive.sh) +for this: +```bash +$ ./scripts/interactive.sh models/modelfile +``` diff --git a/transliteration/config/project.properties b/transliteration/config/project.properties new file mode 100644 index 000000000..e5a4b3d79 --- /dev/null +++ b/transliteration/config/project.properties @@ -0,0 +1,3 @@ +# Use ResourceManager to read these properties +CuratorHost = trollope.cs.illinois.edu +CuratorPort = 9010 diff --git a/transliteration/pom.xml b/transliteration/pom.xml new file mode 100644 index 000000000..81276e690 --- /dev/null +++ b/transliteration/pom.xml @@ -0,0 +1,134 @@ + + + + illinois-cogcomp-nlp + edu.illinois.cs.cogcomp + 3.1.33 + + 4.0.0 + + illinois-transliteration + + + UTF-8 + UTF-8 + + + + + CogcompSoftware + CogcompSoftware + http://cogcomp.cs.illinois.edu/m2repo/ + + + + + + junit + junit + 3.8.1 + test + + + + edu.illinois.cs.cogcomp + illinois-core-utilities + 3.1.33 + + + + org.apache.commons + commons-lang3 + 3.4 + + + junit + junit + 4.12 + test + + + + com.belerweb + pinyin4j + 2.5.0 + + + + org.slf4j + slf4j-log4j12 + 1.7.13 + + + + com.ibm.icu + icu4j + 56.1 + + + + edu.illinois.cs.cogcomp + illinois-abstract-server + 0.1 + + + edu.illinois.cs.cogcomp + curator-interfaces + 0.7 + + + + org.apache.thrift + libthrift + 0.8.0 + + + edu.illinois.cs.cogcomp + curator-utils + 0.0.4-SNAPSHOT + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 2.0.2 + + 1.7 + 1.7 + + + + org.apache.maven.plugins + maven-source-plugin + 2.1.2 + + + attach-sources + + jar + + + + + + + + src/main/resources + + + + + org.apache.maven.wagon + wagon-ssh + 2.4 + + + + + diff --git a/transliteration/scripts/interactive.sh b/transliteration/scripts/interactive.sh new file mode 100755 index 000000000..8af58c267 --- /dev/null +++ b/transliteration/scripts/interactive.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +cpath="target/classes:target/dependency/*:config" +MODEL=$1 + +CMD="java -classpath ${cpath} -Xmx8g edu.illinois.cs.cogcomp.transliteration.Interactive $MODEL" +echo "Running: $CMD" +${CMD} diff --git a/transliteration/scripts/release.sh b/transliteration/scripts/release.sh new file mode 100755 index 000000000..02fe60f58 --- /dev/null +++ b/transliteration/scripts/release.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +if [ "$#" -ne 1 ]; then + echo "usage: $0 " + exit +fi + +# Get the current version +VERSION=`mvn org.apache.maven.plugins:maven-help-plugin:2.1.1:evaluate -Dexpression=project.version | grep -v INFO` + +## DON'T FORGET TO CHANGE VERSION IF THIS IS A NEW RELEASE!!! 
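# (The `grep -v INFO` above filters out Maven's log lines, so VERSION holds only the bare version string.)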
+PACKAGE_NAME=$1 + +echo "The script should run the following commands for package: ${PACKAGE_NAME}-${VERSION}" + +## Deploy the Maven release +echo "mvn javadoc:jar deploy" + +## Update the GitLab repository (also create a tag) +echo "git tag v${VERSION} -m \"Releasing ${PACKAGE_NAME}-${VERSION}\"" + +echo "git push --tags" + + +## Generate the distribution package +echo -n "Generating the distribution package ..." + +## Create a temporary directory +TEMP_DIR="temp90614" +PACKAGE_DIR="${TEMP_DIR}/${PACKAGE_NAME}-${VERSION}" + +mvn dependency:copy-dependencies + +mkdir -p ${PACKAGE_DIR} +mkdir ${PACKAGE_DIR}/lib +mkdir ${PACKAGE_DIR}/dist +mkdir -p ${PACKAGE_DIR}/doc/javadoc +mkdir ${PACKAGE_DIR}/src +mkdir ${PACKAGE_DIR}/scripts + +mv target/${PACKAGE_NAME}-${VERSION}.jar ${PACKAGE_DIR}/dist/ +mv target/${PACKAGE_NAME}-${VERSION}-sources.jar ${PACKAGE_DIR}/src/ +unzip target/${PACKAGE_NAME}-${VERSION}-javadoc.jar -d ${PACKAGE_DIR}/doc/javadoc +mv target/dependency/* ${PACKAGE_DIR}/lib/ +cp doc/* ${PACKAGE_DIR}/doc +cp scripts/* ${PACKAGE_DIR}/scripts + +cd ${TEMP_DIR} +zip -r ../${PACKAGE_NAME}.zip ${PACKAGE_NAME}-${VERSION} +cd .. + +rm -rf ${TEMP_DIR} +echo "Distribution package created: ${PACKAGE_NAME}.zip" diff --git a/transliteration/scripts/runner.sh b/transliteration/scripts/runner.sh new file mode 100755 index 000000000..05e6a870c --- /dev/null +++ b/transliteration/scripts/runner.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +cpath="target/classes:target/dependency/*:config" +DIR="/path/to/transliteration/data" +TRAIN=$DIR/train.data +TEST=$DIR/test.data + +CMD="java -classpath ${cpath} -Xmx8g edu.illinois.cs.cogcomp.transliteration.Runner $TRAIN $TEST" +echo "Running: $CMD" +${CMD} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationAnnotator.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationAnnotator.java new file mode 100644 index 000000000..37f8831b6 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationAnnotator.java @@ -0,0 +1,68 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. 
Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.annotation; + +import edu.illinois.cs.cogcomp.core.datastructures.Pair; +import edu.illinois.cs.cogcomp.core.datastructures.ViewNames; +import edu.illinois.cs.cogcomp.core.datastructures.textannotation.Constituent; +import edu.illinois.cs.cogcomp.core.datastructures.textannotation.TextAnnotation; +import edu.illinois.cs.cogcomp.core.datastructures.textannotation.TokenLabelView; +import edu.illinois.cs.cogcomp.core.datastructures.textannotation.View; +import edu.illinois.cs.cogcomp.core.utilities.configuration.ResourceManager; +import edu.illinois.cs.cogcomp.transliteration.SPModel; +import edu.illinois.cs.cogcomp.utils.TopList; +import jdk.nashorn.internal.parser.Token; + +import java.io.IOException; + +public class TransliterationAnnotator extends Annotator { + + SPModel model; + + public TransliterationAnnotator() { + super(ViewNames.TRANSLITERATION, new String[0]); + } + + public TransliterationAnnotator(boolean lazilyInitialize) { + super(ViewNames.TRANSLITERATION, new String[0], lazilyInitialize); + } + + @Override + public void initialize(ResourceManager rm) { + try { + model = new SPModel(rm.getString(TransliterationConfigurator.MODEL_PATH.key)); + model.setMaxCandidates(1); + } catch (IOException e) { + e.printStackTrace(); + } + } + + @Override + protected void addView(TextAnnotation ta) throws AnnotatorException { + + View v = new TokenLabelView(ViewNames.TRANSLITERATION, this.getClass().getName(), ta, 1.0); + + int index = 0; + for(String tok : ta.getTokens()){ + try { + TopList ll = model.Generate(tok.toLowerCase()); + if(ll.size() > 0) { + Pair toppair = ll.getFirst(); + Constituent c = new Constituent(toppair.getSecond(), toppair.getFirst(), ViewNames.TRANSLITERATION, ta, index, index + 1); + v.addConstituent(c); + } + } catch (Exception e) { + // print that this word has failed... + e.printStackTrace(); + } + + index++; + } + ta.addView(ViewNames.TRANSLITERATION, v); + } +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationConfigurator.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationConfigurator.java new file mode 100644 index 000000000..8b33f7eb4 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationConfigurator.java @@ -0,0 +1,31 @@ +package edu.illinois.cs.cogcomp.annotation; + +import edu.illinois.cs.cogcomp.core.utilities.configuration.Configurator; +import edu.illinois.cs.cogcomp.core.utilities.configuration.Property; +import edu.illinois.cs.cogcomp.core.utilities.configuration.ResourceManager; + +public class TransliterationConfigurator extends Configurator { + + public static final Property LANGUAGE = new Property("usePos", FALSE); + + public static final Property MODEL_PATH = new Property("transliterationModelPath", ""); // todo: fix this + + @Override + public ResourceManager getDefaultConfig() { + Property[] properties = { MODEL_PATH, LANGUAGE }; + return (new TransliterationConfigurator().getConfig(new ResourceManager(generateProperties(properties)))); + } + + /** + * Get a {@link ResourceManager} with non-default properties. Overloaded to merge the properties + * of {@link AnnotatorServiceConfigurator}. + * + * @param nonDefaultRm The non-default properties + * @return a non-null ResourceManager with appropriate values set. 
+ */ + @Override + public ResourceManager getConfig(ResourceManager nonDefaultRm) { + ResourceManager pipelineRm = super.getConfig(nonDefaultRm); + return new AnnotatorServiceConfigurator().getConfig(pipelineRm); + } +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationHandler.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationHandler.java new file mode 100644 index 000000000..98f1b249b --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationHandler.java @@ -0,0 +1,122 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.annotation; + +import edu.illinois.cs.cogcomp.annotation.handler.IllinoisAbstractHandler; +import edu.illinois.cs.cogcomp.core.datastructures.Pair; +import edu.illinois.cs.cogcomp.thrift.base.AnnotationFailedException; +import edu.illinois.cs.cogcomp.thrift.base.Labeling; +import edu.illinois.cs.cogcomp.thrift.base.Span; +import edu.illinois.cs.cogcomp.thrift.curator.Record; +import edu.illinois.cs.cogcomp.thrift.labeler.Labeler; +import edu.illinois.cs.cogcomp.transliteration.Example; +import edu.illinois.cs.cogcomp.transliteration.SPModel; +import edu.illinois.cs.cogcomp.utils.TopList; +import edu.illinois.cs.cogcomp.utils.Utils; +import org.apache.thrift.TException; +import edu.illinois.cs.cogcomp.curator.RecordGenerator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.FileNotFoundException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +/** + * Created by mayhew2 on 1/14/16. + */ +public class TransliterationHandler extends IllinoisAbstractHandler implements Labeler.Iface { + private static Logger logger = LoggerFactory.getLogger(TransliterationHandler.class); + private static final String ARTIFACTNAME = "illinois-transliteration"; + private static final String VERSION = "1.0.0"; + private static final String PACKAGENAME = "Illinois Transliteration"; + + private SPModel model; + + public TransliterationHandler(String trainfilename) throws FileNotFoundException { + super(PACKAGENAME, VERSION, ARTIFACTNAME); + + boolean fix = true; + + List training = Utils.readWikiData(trainfilename, fix); + + model = new SPModel(training.subList(0,200)); + model.Train(2); + + + } + + @Override + public Labeling labelRecord(Record record) throws AnnotationFailedException, TException { + + String text = record.getRawText().toLowerCase(); + List words = Arrays.asList(text.split(" ")); + + int limit = 5; + if (words.size() > limit){ + logger.info("Transliteration handler does not handle text with more than {} tokens. 
Just annotating the first {}...", limit, limit); + words = words.subList(0,limit); + } + + List labels = new ArrayList<>(); + + int i = 0; + for(String name : words) { + + if(name.length() == 0){ + i += 1; + continue; + } + + logger.debug("[" + name + "]"); + + TopList res = null; + + Span span = new Span(); + span.setStart(i); + span.setEnding(i + name.length()); + try { + res = model.Generate(name); + Pair best = res.getFirst(); + span.setLabel(best.getSecond()); + + } catch (Exception e) { + e.printStackTrace(); + span.setLabel("UNLABELED"); + } + + labels.add(span); + + i += name.length()+1; // extra 1 is for spaces... + } + + Labeling labeling = new Labeling(); + labeling.setSource(getSourceIdentifier()); + labeling.setLabels(labels); + return labeling; + } + + public static void main(String[] args) throws Exception { + String text = "whatever"; + + String wikidata = "/shared/corpora/transliteration/wikidata/wikidata.Hindi"; + + TransliterationHandler handler = new TransliterationHandler(wikidata); + Record input = RecordGenerator.generateTokenRecord( text, false ); + + Labeling labels = handler.labelRecord(input); + for(Iterator label = labels.getLabelsIterator(); label.hasNext() ; ) { + Span span = label.next(); + System.out.println("["+span.start+"-"+span.ending+"]"); + System.out.println(text.substring(span.start, span.ending)+"\t:\t"+span.getLabel()); + } + } + +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationServer.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationServer.java new file mode 100644 index 000000000..b5d938cd9 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/annotation/TransliterationServer.java @@ -0,0 +1,45 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.annotation; + +import edu.illinois.cs.cogcomp.annotation.server.IllinoisAbstractServer; +import edu.illinois.cs.cogcomp.thrift.labeler.Labeler; +import org.apache.commons.cli.Options; + +/** + * Created by mayhew2 on 1/14/16. + */ +public class TransliterationServer extends IllinoisAbstractServer { + public TransliterationServer(Class c) { + super(c); + } + + public static void main(String[] args) { + + TransliterationServer s = new TransliterationServer(TransliterationServer.class); + + Options options = createOptions(); + + s.parseCommandLine(options, args); + + Labeler.Iface handler = null; + Labeler.Processor processor = null; + + String trainingfile = "/shared/corpora/transliteration/wikidata/wikidata.Hindi"; + try { + handler = new TransliterationHandler(trainingfile); + processor = new Labeler.Processor(handler); + } catch (Exception e) { + s.logger.error("Couldn't start the handler.... 
the exception was\n"+e.toString(), e.toString()); + System.exit(0); + } + + s.runServer(processor); + } + +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/CSPModel.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/CSPModel.java new file mode 100644 index 000000000..7010cc600 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/CSPModel.java @@ -0,0 +1,118 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import edu.illinois.cs.cogcomp.core.datastructures.Pair; +import edu.illinois.cs.cogcomp.core.datastructures.Triple; +import edu.illinois.cs.cogcomp.utils.SparseDoubleVector; + +import java.util.List; + +/** + * Created by stephen on 9/25/15. + */ +// also extends iCloneable. Hmm. +class CSPModel extends TransliterationModel +{ + + enum SegMode + { + None, + Count, + Entropy, + Best + } + + enum SmoothMode + { + BySource, //smooth based on length of the source substring only + ByMax, //smooth based on the maximum of lengths of source and target substrings + BySum //smooth based on the sum of lengths of the source and target substrings + } + + enum EMMode + { + Normal, + MaxSourceSeg, //assume every source segment is valid (ex.: not true for "p" or "h" in "phone") and, in each example, find the "true" generated target language substring by giving a weight of 1 to the most likely production, and 0 to everything else + Smoothed, //apply smoothing in EM + BySourceSeg + } + + public CSPModel() { + } + + /** + * + * @param maxSubstringLength + * @param segContextSize + * @param productionContextSize + * @param minProductionProbability + * @param segMode + * @param syllabic + * @param smoothMode + * @param fallbackStrategy + * @param emMode + * @param underflowChecking True to check for \sum_t P(t|s) == 0 after normalizing production counts, where t and s are segments of the target and source word, respectively. If such (total) underflow occurs, the previous iteration's conditional probabilities are used instead. 
+ */ + public CSPModel(int maxSubstringLength, int segContextSize, int productionContextSize, double minProductionProbability, SegMode segMode, boolean syllabic, SmoothMode smoothMode, FallbackStrategy fallbackStrategy, EMMode emMode, boolean underflowChecking) + { + this.maxSubstringLength = maxSubstringLength; + this.segContextSize = segContextSize; + this.productionContextSize = productionContextSize; + this.minProductionProbability = minProductionProbability; + this.fallbackStrategy = fallbackStrategy; + this.syllabic = syllabic; + //this.updateSegProbs = updateSegProbs; + this.segMode = segMode; + this.smoothMode = smoothMode; + this.emMode = emMode; + this.underflowChecking = underflowChecking; + } + + public Boolean underflowChecking; + + public EMMode emMode; + public SegMode segMode; + public SmoothMode smoothMode; + + //public bool updateSegProbs; + public Boolean syllabic; + + public SparseDoubleVector, String>> productionProbs; + public SparseDoubleVector> segProbs; + + //public SparseDoubleVector, String>> productionCounts; + //public SparseDoubleVector> segCounts; + + public int segContextSize; + public int productionContextSize; + public int maxSubstringLength; + + public double minProductionProbability; + + public FallbackStrategy fallbackStrategy; + + public Object clone() + { + return this.clone(); + } + + + @Override + public double GetProbability(String word1, String word2) + { + return CSPTransliteration.GetProbability(word1, word2, this); + } + + @Override + public TransliterationModel LearnModel(List> examples) { + return CSPTransliteration.LearnModel(examples, this); + } + + +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/CSPTransliteration.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/CSPTransliteration.java new file mode 100644 index 000000000..b8cb4fd7b --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/CSPTransliteration.java @@ -0,0 +1,648 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import edu.illinois.cs.cogcomp.core.datastructures.Pair; +import edu.illinois.cs.cogcomp.core.datastructures.Triple; +import edu.illinois.cs.cogcomp.utils.InternDictionary; +import edu.illinois.cs.cogcomp.utils.SparseDoubleVector; +import org.apache.commons.lang3.StringUtils; + +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; + +/** + * I don't know what CSPTransliteration or CSPModel are for. 
--Stephen + */ +class CSPTransliteration +{ + + public static SparseDoubleVector, String>> CreateFallbackPair(SparseDoubleVector, String>> counts, InternDictionary internTable) + { + SparseDoubleVector, String>> result = new SparseDoubleVector<>(); + for (Pair, String> key : counts.keySet()) + { + Double value = counts.get(key); + for (int i = 0; i <= key.getFirst().getFirst().length(); i++) + { + Pair, String> reskey = new Pair<>( + new Triple<>( + internTable.Intern(key.getFirst().getFirst().substring(i)), + internTable.Intern(key.getFirst().getSecond()), + internTable.Intern(key.getFirst().getThird().substring(0, key.getFirst().getThird().length() - i))), key.getSecond()); + result.put(reskey, result.get(reskey) + value); + + } + } + + return result; + } + + public static SparseDoubleVector> CreateFallback(SparseDoubleVector> segCounts, InternDictionary internTable) + { + SparseDoubleVector> result = new SparseDoubleVector<>(); + for (Triple key : segCounts.keySet()) + { + Double value = segCounts.get(key); + for (int i = 0; i <= key.getFirst().length(); i++) + { + Triple reskey = new Triple<>(internTable.Intern(key.getFirst().substring(i)), + internTable.Intern(key.getSecond()), + internTable.Intern(key.getThird().substring(0, key.getThird().length() - i))); + result.put(reskey, result.get(reskey) + value); + } + } + + return result; + } + + public static CSPModel LearnModel(List> examples, CSPModel model) + { + // odd notation... + CSPExampleCounts counts = new CSPTransliteration().new CSPExampleCounts(); + //CSPExampleCounts intCounts = new CSPExampleCounts(); + + counts.productionCounts = new SparseDoubleVector<>(); + counts.segCounts = new SparseDoubleVector<>(); + + for (Triple example : examples) + { + CSPExampleCounts exampleCount = LearnModel(example.getFirst(), example.getSecond(), model); + counts.productionCounts.put(example.getThird(), exampleCount.productionCounts); + counts.segCounts.put(example.getThird(), exampleCount.segCounts); + + //intCounts.productionCounts += exampleCount.productionCounts.Sign(); + //intCounts.segCounts += exampleCount.segCounts.Sign(); + } + + CSPModel result = (CSPModel)model.clone(); + + //normalize to get "joint" + result.productionProbs = counts.productionCounts.divide(counts.productionCounts.PNorm(1)); + + InternDictionary internTable = new InternDictionary<>(); + + SparseDoubleVector, String>> oldProbs=null; + if (model.underflowChecking) + oldProbs = result.productionProbs; + + //now get production fallbacks + result.productionProbs = CreateFallbackPair(result.productionProbs, internTable); + + //finally, make it conditional + result.productionProbs = PSecondGivenFirst(result.productionProbs); + + if (model.underflowChecking) + { + //go through and ensure that Sum_X(P(Y|X)) == 1...or at least > 0! 
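// First sum the conditional mass for every source-context triple; any
// production whose own probability and whose context total are both zero
// (i.e. the whole context underflowed) is then restored from oldProbs, the
// previous iteration's table, rather than being left at zero.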
+ SparseDoubleVector> sums = new SparseDoubleVector<>(); + for (Pair, String> key : model.productionProbs.keySet()) { + Double value = model.productionProbs.get(key); + + Triple ff = key.getFirst(); + + Double val = sums.get(ff); + sums.put(ff, val + value); + } + + List, String>> restoreList = new ArrayList<>(); + + for (Pair, String> key : model.productionProbs.keySet()) { + Double value = model.productionProbs.get(key); + + if (value == 0 && sums.get(key.getFirst()) == 0) + restoreList.add(key); + } + + + for (Pair, String> pair : restoreList) { + //model.productionProbs[pair] = oldProbs[pair]; + model.productionProbs.put(pair, oldProbs.get(pair)); + } + + } + + //get conditional segmentation probs + if (model.segMode == CSPModel.SegMode.Count) + result.segProbs = CreateFallback(PSegGivenFlatOccurence(counts.segCounts,examples,model.segContextSize),internTable); // counts.segCounts / counts.segCounts.PNorm(1); + else if (model.segMode == CSPModel.SegMode.Entropy) + { + SparseDoubleVector> totals = new SparseDoubleVector<>(); + for (Pair, String> key : result.productionProbs.keySet()) + { + Double value = result.productionProbs.get(key); + + //totals[pair.Key.x] -= pair.Value * Math.Log(pair.Value, 2); + double logValue = value * Math.log(value); + + if (!Double.isNaN(logValue)) { + Double lv = totals.get(key.getFirst()); + totals.put(key.getFirst(), lv + logValue); + } + //totals[pair.Key.x] *= Math.Pow(pair.Value, pair.Value); + } + + result.segProbs = totals.Exp(); //totals.Max(0.000001).Pow(-1); + } + + //return the finished model + return result; + } + + private static Triple GetContextTriple(String originalWord, int index, int length, int contextSize) + { + return new Triple<>(WikiTransliteration.GetLeftContext(originalWord, index, contextSize), + originalWord.substring(index, length), + WikiTransliteration.GetRightContext(originalWord, index + length, contextSize)); + } + + public static SparseDoubleVector> PSegGivenFlatOccurence(SparseDoubleVector> segCounts, List> examples, int contextSize) + { + SparseDoubleVector> counts = new SparseDoubleVector<>(); + for (Triple example : examples) + { + String word = StringUtils.repeat('_', contextSize) + example.getFirst() + StringUtils.repeat('_', contextSize); + for (int i = contextSize; i < word.length() - contextSize; i++) { + for (int j = i; j < word.length() - contextSize; j++) { + Triple gct = GetContextTriple(word, i, j - i + 1, contextSize); + counts.put(gct, counts.get(gct) + example.getThird()); + } + } + } + + return segCounts.divide(counts); + } +// +// public static SparseDoubleVector> PSegGivenLength(SparseDoubleVector> segCounts) +// { +// SparseDoubleVector sums = new SparseDoubleVector<>(); +// for (Pair, Double> pair : segCounts) +// sums[pair.Key.y.Length] += pair.Value; +// +// SparseDoubleVector> result = new SparseDoubleVector>(segCounts.Count); +// for (Pair, Double> pair : segCounts) +// result[pair.Key] = pair.Value / sums[pair.Key.y.Length]; +// +// return result; +// } + + public static SparseDoubleVector, String>> PSecondGivenFirst(SparseDoubleVector, String>> counts) + { + SparseDoubleVector> totals = new SparseDoubleVector<>(); + for (Pair, String> key : counts.keySet()) { + Double value = counts.get(key); + totals.put(key.getFirst(), totals.get(key.getFirst()) + value); + } + + SparseDoubleVector, String>> result = new SparseDoubleVector<>(counts.size()); + for (Pair, String> key : counts.keySet()) + { + Double value = counts.get(key); + double total = totals.get(key.getFirst()); + if (total == 0) { + 
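// no mass observed for this source context; emit 0 rather than divide by zero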
result.put(key, 0.0); + }else{ + result.put(key, value/total); + } + } + + return result; + } + + public static CSPExampleCounts LearnModel(String word1, String word2, CSPModel model) + { + CSPExampleCounts result = new CSPTransliteration().new CSPExampleCounts(); + + int paddingSize = Math.max(model.productionContextSize, model.segContextSize); + String paddedWord = StringUtils.repeat('_', paddingSize) + word1 + StringUtils.repeat('_', paddingSize); + HashMap,Pair>,Double>> lastArg = new HashMap<>(); + + Pair>, Double> raw = LearnModel(paddingSize, paddedWord, word1, word2, model, lastArg); + + if (raw.getSecond() == 0) + raw.setFirst(new SparseDoubleVector>()); + else + //raw.x = Program.segSums[Math.Min(39,word1.length()-1)][Math.Min(39,word2.length()-1)] * (raw.x); + raw.setFirst( raw.getFirst().divide(raw.getSecond()) ); + //raw.x = raw.y >= 1 ? raw.x : raw.x / raw.y; + + if (model.emMode == CSPModel.EMMode.MaxSourceSeg) + { + HashMap, Triple> bestProdProbs = new HashMap<>(); + SparseDoubleVector> maxProdProbs = new SparseDoubleVector<>(); + for (Triple key : raw.getFirst().keySet()) { + Double value = raw.getFirst().get(key); + + Pair keyXY = new Pair<>(key.getFirst(), key.getSecond()); + + if (maxProdProbs.get(keyXY) < value) { + bestProdProbs.put(keyXY, key); + maxProdProbs.put(keyXY, value); + } + } + + raw.getFirst().Clear(); + for (Triple triple : bestProdProbs.values()) + raw.getFirst().put(triple, 1.0); + + } + else if (model.emMode == CSPModel.EMMode.BySourceSeg) + { + //Dictionary, Triple> bestProdProbs = new Dictionary, Triple>(); + SparseDoubleVector> sumProdProbs = new SparseDoubleVector<>(); + for (Triple key : raw.getFirst().keySet()) { + Double value = raw.getFirst().get(key); + Pair keyXY = new Pair<>(key.getFirst(), key.getSecond()); + sumProdProbs.put(keyXY, sumProdProbs.get(keyXY) + value); + } + + SparseDoubleVector> newCounts = new SparseDoubleVector<>(raw.getFirst().size()); + for (Triple key : raw.getFirst().keySet()) { + Double value = raw.getFirst().get(key); + Pair keyXY = new Pair<>(key.getFirst(), key.getSecond()); + newCounts.put(key, value / sumProdProbs.get(keyXY)); + } + raw.setFirst(newCounts); + } + + result.productionCounts = new SparseDoubleVector<>(raw.getFirst().size()); + result.segCounts = new SparseDoubleVector<>(raw.getFirst().size()); + + for (Triple key : raw.getFirst().keySet()) { + Double value = raw.getFirst().get(key); + Pair, String> pckey = new Pair<>(new Triple<>(WikiTransliteration.GetLeftContext(paddedWord, key.getFirst(), model.productionContextSize), key.getSecond(), WikiTransliteration.GetRightContext(paddedWord, key.getFirst() + key.getSecond().length(), model.productionContextSize)), key.getThird()); + result.productionCounts.put(pckey, result.productionCounts.get(pckey) + value); + + Triple sckey = new Triple<>(WikiTransliteration.GetLeftContext(paddedWord, key.getFirst(), model.segContextSize), key.getSecond(), WikiTransliteration.GetRightContext(paddedWord, key.getFirst() + key.getSecond().length(), model.segContextSize)); + result.segCounts.put(sckey, result.segCounts.get(sckey) + value); + } + + return result; + } + + public static char[] vowels = new char[] { 'a', 'e', 'i', 'o', 'u', 'y' }; + + /** + * Gets counts for productions by (conceptually) summing over all the possible alignments + * and weighing each alignment (and its constituent productions) by the given probability table. 
+ * probSum is important (and memoized for input word pairs)--it keeps track and returns the sum of the probabilities of all possible alignments for the word pair + * @param position ? + * @param originalWord1 ? + * @param word1 ? + * @param word2 ? + * @param model ? + * @param memoizationTable ? + * @return ? + */ + public static Pair>, Double> LearnModel(int position, String originalWord1, String word1, String word2, CSPModel model, HashMap, Pair>, Double>> memoizationTable) + { + Pair>, Double> memoization; + + Triple check = new Triple<>(position, word1, word2); + if(memoizationTable.containsKey(check)){ + return memoizationTable.get(check); + } + + Pair>, Double> result + = new Pair<>(new SparseDoubleVector>(), 0.0); + + if (word1.length() == 0 && word2.length() == 0) //record probabilities + { + result.setSecond(1.0); //null -> null is always a perfect alignment + return result; //end of the line + } + + int maxSubstringLength1f = Math.min(word1.length(), model.maxSubstringLength); + int maxSubstringLength2f = Math.min(word2.length(), model.maxSubstringLength); + + String[] leftContexts = WikiTransliteration.GetLeftFallbackContexts(originalWord1,position, Math.max(model.segContextSize, model.productionContextSize)); + + int firstVowel = -1; int secondVowel = -1; + if (model.syllabic) + { + for (int i = 0; i < word1.length(); i++) { + + if (Arrays.asList(vowels).contains(word1.charAt(i))){ + firstVowel = i; + }else if(firstVowel >= 0){ + break; + } + } + + if (firstVowel == -1) + firstVowel = word1.length() - 1; //no vowels! + + for (int i = firstVowel + 1; i < word1.length(); i++) { + if (Arrays.asList(vowels).contains(word1.charAt(i))) { + secondVowel = i; + break; + } + } + + if (secondVowel == -1 || (secondVowel == word1.length() - 1 && word1.charAt(secondVowel) == 'e')) //if only one vowel, only consider the entire thing; note consideration of silent 'e' at end of words + { + firstVowel = maxSubstringLength1f - 1; + secondVowel = maxSubstringLength1f; + } + } + else + { + firstVowel = 0; + secondVowel = maxSubstringLength1f; + } + + for (int i = firstVowel + 1; i <= secondVowel; i++) + //for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... + { + String substring1 = word1.substring(0, i); + String[] rightContexts = WikiTransliteration.GetRightFallbackContexts(originalWord1, position + i, Math.max(model.segContextSize, model.productionContextSize)); + + double segProb; + if (model.segProbs.size() == 0) + segProb = 1; + else + { + segProb = 0; + int minK = model.fallbackStrategy == FallbackStrategy.NotDuringTraining ? model.segContextSize : 0; + for (int k = model.segContextSize; k >= minK; k--){ + if (model.segProbs.containsKey(new Triple<>(leftContexts[k], substring1, rightContexts[k]))){ + segProb = model.segProbs.get(new Triple<>(leftContexts[k], substring1, rightContexts[k])); + break; + } + } + } + + for (int j = 1; j <= maxSubstringLength2f; j++) //foreach possible substring in the second + { + if ((word1.length() - i) * model.maxSubstringLength >= word2.length() - j && (word2.length() - j) * model.maxSubstringLength >= word1.length() - i) //if we get rid of these characters, can we still cover the remainder of word2? 
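// i.e. prune alignments that can no longer cover a remaining word: each
// leftover source character can generate at most maxSubstringLength target
// characters, and vice versa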
+ { + String substring2 = word2.substring(0, j); + //Pair, String> production = new Pair, String>(new Triple(leftProductionContext, substring1, rightProductionContext), substring2); + + double prob; + if (model.productionProbs.size() == 0) + prob = 1; + else + { + prob = 0; + int minK = model.fallbackStrategy == FallbackStrategy.NotDuringTraining ? model.productionContextSize : 0; + for (int k = model.productionContextSize; k >= minK; k--) { + Pair, String> v = new Pair<>(new Triple<>(leftContexts[k], substring1, rightContexts[k]), substring2); + if (model.productionProbs.containsKey(v)) { + prob = model.productionProbs.get(v); + break; + } + } + + if (model.emMode == CSPModel.EMMode.Smoothed) prob = Math.max(model.minProductionProbability, prob); + } + + Pair>, Double> remainder = LearnModel(position + i, originalWord1, word1.substring(i), word2.substring(j), model, memoizationTable); + + double cProb = prob * segProb; + + //record this production in our results + + result.getFirst().put(cProb, remainder.getFirst()); + result.setSecond(result.getSecond() + remainder.getSecond() * cProb); + + Triple pp = new Triple<>(position, substring1, substring2); + result.getFirst().put(pp, result.getFirst().get(pp) + cProb * remainder.getSecond()); + } + } + } + + memoizationTable.put(new Triple<>((double)position, word1, word2), result); + return result; + } + + + public static double GetProbability(String word1, String word2, CSPModel model) + { + int paddingSize = Math.max(model.productionContextSize, model.segContextSize); + String paddedWord = StringUtils.repeat('_', paddingSize) + word1 + StringUtils.repeat('_', paddingSize); + + if (model.segMode != CSPModel.SegMode.Best) + { + Pair raw = GetProbability(paddingSize, paddedWord, word1, word2, model, new HashMap, Pair>()); + return raw.getFirst() / raw.getSecond(); //normalize the segmentation probabilities by dividing by the sum of probabilities for all segmentations + } + else + return GetBestProbability(paddingSize, paddedWord, word1, word2, model, new HashMap, Double>()); + } + + //Gets the "best" alignment for a given word pair, defined as max P(s,t|S,T). + public static double GetBestProbability(int position, String originalWord1, String word1, String word2, CSPModel model, HashMap, Double> memoizationTable) + { + double result; + Triple v = new Triple(position, word1, word2); + if (memoizationTable.containsKey(v)){ + return memoizationTable.get(v); //we've been down this road before + } + + + result = 0; + + if (word1.length() == 0 && word2.length() == 0) + return 1; //perfect null-to-null alignment + + + int maxSubstringLength1f = Math.min(word1.length(), model.maxSubstringLength); + int maxSubstringLength2f = Math.min(word2.length(), model.maxSubstringLength); + + String[] leftContexts = WikiTransliteration.GetLeftFallbackContexts(originalWord1, position, Math.max(model.segContextSize, model.productionContextSize)); + + double minProductionProbability1 = 1; + + for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... 
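// Unlike GetProbability below, this recursion keeps only the max over
// alignments (a Viterbi-style best alignment) instead of summing them.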
+ { + minProductionProbability1 *= model.minProductionProbability; + + String substring1 = word1.substring(0, i); + String[] rightContexts = WikiTransliteration.GetRightFallbackContexts(originalWord1, position + i, Math.max(model.segContextSize, model.productionContextSize)); + + double segProb; + if (model.segProbs.size() == 0) + segProb = 1; + else + { + segProb = 0; + for (int k = model.productionContextSize; k >= 0; k--) { + Triple v5 = new Triple<>(leftContexts[k], substring1, rightContexts[k]); + if (model.segProbs.containsKey(v5)) { + segProb = model.segProbs.get(v5); + break; + } + } + } + + double minProductionProbability2 = 1; + for (int j = 1; j <= maxSubstringLength2f; j++) //foreach possible substring in the second + { + minProductionProbability2 *= model.minProductionProbability; + + if ((word1.length() - i) * model.maxSubstringLength >= word2.length() - j && (word2.length() - j) * model.maxSubstringLength >= word1.length() - i) //if we get rid of these characters, can we still cover the remainder of word2? + { + double minProductionProbability; + if (model.smoothMode == CSPModel.SmoothMode.BySource) + minProductionProbability = minProductionProbability1; + else if (model.smoothMode == CSPModel.SmoothMode.ByMax) + minProductionProbability = Math.min(minProductionProbability1, minProductionProbability2); + else //if (model.smoothMode == SmoothMode.BySum) + minProductionProbability = minProductionProbability1 * minProductionProbability2; + + String substring2 = word2.substring(0, j); + //Pair, String> production = new Pair, String>(new Triple(leftProductionContext, substring1, rightProductionContext), substring2); + + double prob; + if (model.productionProbs.size() == 0) + prob = 1; + else + { + prob = 0; + for (int k = model.productionContextSize; k >= 0; k--) { + Pair, String> v4 = new Pair<>(new Triple<>(leftContexts[k], substring1, rightContexts[k]), substring2); + if (model.productionProbs.containsKey(v4)){ + prob = model.productionProbs.get(v4); + break; + } + + } + + prob = Math.max(prob, minProductionProbability); + } + + double remainder = prob * GetBestProbability(position + i, originalWord1, word1.substring(i), word2.substring(j), model, memoizationTable); + + if (remainder > result) result = remainder; //maximize + + //record this remainder in our results + //result.x += remainder.x * prob * segProb; + //result.y += remainder.y * segProb; + } + } + } + + memoizationTable.put(new Triple<>(position, word1, word2), result); + return result; + } + +// //Gets counts for productions by (conceptually) summing over all the possible alignments +// //and weighing each alignment (and its constituent productions) by the given probability table. 
+// //probSum is important (and memoized for input word pairs)--it keeps track and returns the sum of the probabilities of all possible alignments for the word pair + public static Pair GetProbability(int position, String originalWord1, String word1, String word2, CSPModel model, HashMap, Pair> memoizationTable) + { + Pair result; + Triple v = new Triple<>(position, word1, word2); + if(memoizationTable.containsKey(v)){ + return memoizationTable.get(v); + } + + result = new Pair<>(0.0, 0.0); + + if (word1.length() == 0 && word2.length() == 0) //record probabilities + { + result.setFirst(1.0); //null -> null is always a perfect alignment + result.setSecond(1.0); + return result; //end of the line + } + + int maxSubstringLength1f = Math.min(word1.length(), model.maxSubstringLength); + int maxSubstringLength2f = Math.min(word2.length(), model.maxSubstringLength); + + String[] leftContexts = WikiTransliteration.GetLeftFallbackContexts(originalWord1, position, Math.max(model.segContextSize, model.productionContextSize)); + + double minProductionProbability1 = 1; + + for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... + { + minProductionProbability1 *= model.minProductionProbability; + + String substring1 = word1.substring(0, i); + String[] rightContexts = WikiTransliteration.GetRightFallbackContexts(originalWord1, position + i, Math.max(model.segContextSize, model.productionContextSize)); + + double segProb; + if (model.segProbs.size() == 0) + segProb = 1; + else + { + segProb = 0; + for (int k = model.productionContextSize; k >= 0; k--) { + + Triple v2 = new Triple<>(leftContexts[k], substring1, rightContexts[k]); + if (model.segProbs.containsKey(v2)) { + segProb = model.segProbs.get(v2); + break; + } + } + } + + double minProductionProbability2 = 1; + for (int j = 1; j <= maxSubstringLength2f; j++) //foreach possible substring in the second + { + minProductionProbability2 *= model.minProductionProbability; + + if ((word1.length() - i) * model.maxSubstringLength >= word2.length() - j && (word2.length() - j) * model.maxSubstringLength >= word1.length() - i) //if we get rid of these characters, can we still cover the remainder of word2? 
+ { + double minProductionProbability; + if (model.smoothMode == CSPModel.SmoothMode.BySource) + minProductionProbability = minProductionProbability1; + else if (model.smoothMode == CSPModel.SmoothMode.ByMax) + minProductionProbability = Math.min(minProductionProbability1, minProductionProbability2); + else //if (model.smoothMode == SmoothMode.BySum) + minProductionProbability = minProductionProbability1 * minProductionProbability2; + + String substring2 = word2.substring(0, j); + //Pair, String> production = new Pair, String>(new Triple(leftProductionContext, substring1, rightProductionContext), substring2); + + double prob; + if (model.productionProbs.size() == 0) + prob = 1; + else + { + prob = 0; + for (int k = model.productionContextSize; k >= 0; k--) { + Pair, String> v3 = new Pair<>(new Triple<>(leftContexts[k], substring1, rightContexts[k]), substring2); + if (model.productionProbs.containsKey(v3)) { + prob = model.productionProbs.get(v3); + break; + } + } + + prob = Math.max(prob, minProductionProbability); + } + + Pair remainder = GetProbability(position + i, originalWord1, word1.substring(i), word2.substring(j), model, memoizationTable); + + //record this remainder in our results + result.setFirst(result.getFirst() + remainder.getFirst() * prob * segProb); + result.setSecond(result.getSecond() + remainder.getSecond() * segProb); + + } + } + } + + memoizationTable.put(new Triple<>(position, word1, word2), result); + return result; + } + + + + class CSPExampleCounts + { + public SparseDoubleVector, String>> productionCounts; + public SparseDoubleVector> segCounts; + } + +} + + + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Context.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Context.java new file mode 100644 index 000000000..90fd555c8 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Context.java @@ -0,0 +1,22 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +// This file used to be called ContextualWord +// This used to be a struct. + +public class Context { + public Context(String leftContext, String rightContext) { + this.leftContext = leftContext; + this.rightContext = rightContext; + } + + public String leftContext; + public String rightContext; +} + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Example.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Example.java new file mode 100644 index 000000000..9e2006c91 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Example.java @@ -0,0 +1,61 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import edu.illinois.cs.cogcomp.core.datastructures.Triple; + +/** + * Represents a training example for the transliteration model. 
Examples may be weighted to make them more or less + * relatively important. + */ +public class Example extends MultiExample { + + /** + * Creates a new training example with weight 1. + * @param sourceWord + * @param transliteratedWord + */ + public Example(String sourceWord, String transliteratedWord) + { + this(sourceWord, transliteratedWord, 1); + } + + /** + * Creates a new training example with the specified weight. + * @param sourceWord + * @param transliteratedWord + * @param weight + */ + public Example(String sourceWord, String transliteratedWord, double weight) { + super(sourceWord, transliteratedWord, weight); + } + + public String getTransliteratedWord(){ + return this.transliteratedWords.get(0); + } + + /** + * This used to be a field, with a get{} method. + * @return + */ + Triple Triple() + { + return new Triple<>(sourceWord, transliteratedWords.get(0), weight); + } + + /** + * Gets a "reversed" copy of this example, with the source and transliterated words swapped. + * This used to be a field, with a get() method. + */ + public Example Reverse(){ + return new Example(transliteratedWords.get(0), sourceWord, weight); + } + + +} + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/FallbackStrategy.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/FallbackStrategy.java new file mode 100644 index 000000000..8ab75ea9b --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/FallbackStrategy.java @@ -0,0 +1,15 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +enum FallbackStrategy { + Standard, + Average, + NotDuringTraining +} + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Interactive.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Interactive.java new file mode 100644 index 000000000..46de1b8e9 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Interactive.java @@ -0,0 +1,62 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import edu.illinois.cs.cogcomp.core.datastructures.Pair; +import edu.illinois.cs.cogcomp.utils.TopList; + +import java.util.Iterator; +import java.util.Scanner; + +/** + * Created by mayhew2 on 5/30/16. 
+ */ +public class Interactive { + + public static void main(String[] args) throws Exception { + String modelname = args[0]; + + interactive(modelname); + } + + static void interactive(String modelname) throws Exception { + SPModel model = new SPModel(modelname); + + //List arabicStrings = Program.getForeignWords(wikidata + "wikidata.Armenian"); + //model.SetLanguageModel(arabicStrings); + + Scanner scanner = new Scanner(System.in); + + while(true){ + System.out.print("Enter something: "); + String name = scanner.nextLine().toLowerCase(); + + if(name.equals("exit")){ + break; + } + + System.out.println(name); + + TopList cands = model.Generate(name); + Iterator> ci = cands.iterator(); + + int lim = Math.min(5, cands.size()); + + if(lim == 0){ + System.out.println("No candidates for this..."); + }else { + for (int i = 0; i < lim; i++) { + Pair p = ci.next(); + System.out.println(p.getFirst() + ": " + p.getSecond()); + } + } + } + } + + +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/MultiExample.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/MultiExample.java new file mode 100644 index 000000000..f6d8ad1df --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/MultiExample.java @@ -0,0 +1,136 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import java.util.ArrayList; +import java.util.List; + +/** + * This is meant to fulfill the same role as the Example class, but allows multiple + * targetwords to be associated with a single source word. + * + * This is intended primarily for testing on the NEWS dataset. + * Created by mayhew2 on 11/13/15. + */ +public class MultiExample { + + /** + * The word in the source language. + */ + public String sourceWord; + + /** + * The transliterated word. + */ + public List transliteratedWords; + + public List getTransliteratedWords(){ + return this.transliteratedWords; + } + + /** + * The relative important of this example. A weight of 1 means normal importance. + */ + public double weight; + + public MultiExample(String sourceWord, List transliteratedWords, double weight){ + this.sourceWord = sourceWord; + this.transliteratedWords = transliteratedWords; + this.weight = weight; + + } + + public MultiExample(String sourceWord, List transliteratedWords){ + this(sourceWord, transliteratedWords, 1); + } + + public MultiExample(String sourceWord, String transliteratedWord, double weight){ + this.sourceWord = sourceWord; + this.transliteratedWords = new ArrayList<>(); + this.transliteratedWords.add(transliteratedWord); + this.weight = weight; + } + + public MultiExample(String sourceWord, String transliteratedWord){ + this(sourceWord, transliteratedWord, 1); + } + + public void addTransliteratedWord(String tlw){ + this.transliteratedWords.add(tlw); + } + + /** + * Normalizes a Hebrew word by replacing end-form characters with their in-word equivalents. + * @param hebrewWord + * @return + */ + public static String NormalizeHebrew(String hebrewWord) { + return Program.NormalizeHebrew(hebrewWord); + } + + /** + * Removes accents from characters. 
+ * This can be a useful fallback method if the model cannot make a prediction + * over a given word because it has not seen a particular accented character before. + * @param word + * @return + */ + public static String StripAccents(String word) { + return Program.StripAccent(word); + } + + /** + * Converts this MultiExample into a list of Examples, one for each transliterated word. Each Example has + * the same sourceWord. + * @return + */ + public List toExampleList(){ + List out = new ArrayList<>(); + + for(String t : transliteratedWords){ + out.add(new Example(this.sourceWord, t, this.weight)); + } + + return out; + } + + @Override + public String toString() { + String classname = this.getClass().getSimpleName(); + + return classname + "{" + + "sourceWord='" + sourceWord + '\'' + + ", transliteratedWords=" + transliteratedWords + + ", weight=" + weight + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + MultiExample that = (MultiExample) o; + + if (Double.compare(that.weight, weight) != 0) return false; + if (sourceWord != null ? !sourceWord.equals(that.sourceWord) : that.sourceWord != null) return false; + return transliteratedWords != null ? transliteratedWords.equals(that.transliteratedWords) : that.transliteratedWords == null; + + } + + @Override + public int hashCode() { + int result; + long temp; + result = sourceWord != null ? sourceWord.hashCode() : 0; + result = 31 * result + (transliteratedWords != null ? transliteratedWords.hashCode() : 0); + temp = Double.doubleToLongBits(weight); + result = 31 * result + (int) (temp ^ (temp >>> 32)); + return result; + } +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Production.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Production.java new file mode 100644 index 000000000..cee645f48 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Production.java @@ -0,0 +1,84 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + + +/** + * This is mainly intended to introduce clarity, not add any functionality. + * Created by mayhew2 on 11/5/15. + */ +public class Production { + + /** + * Source segment + */ + String segS; + + /** + * Target segment + */ + String segT; + + public String getFirst(){ + return segS; + } + + public String getSecond(){ + return segT; + } + + public int getOrigin(){ + return origin; + } + + public void setOrigin(int o){ + this.origin = o; + } + + /** + * Language of origin (will just be an ID, won't actually know the country) + * -1 is the default id. 
+     */
+    int origin;
+
+    public Production(String segS, String segT) {
+        this(segS, segT, -1);
+    }
+
+    public Production(String segS, String segT, int origin) {
+        this.segS = segS;
+        this.segT = segT;
+        this.origin = origin;
+    }
+
+    @Override
+    public int hashCode() {
+        int result = segS.hashCode();
+        result = 31 * result + segT.hashCode();
+        result = 31 * result + origin;
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        Production that = (Production) o;
+
+        if (origin != that.origin) return false;
+        if (segS != null ? !segS.equals(that.segS) : that.segS != null) return false;
+        return !(segT != null ? !segT.equals(that.segT) : that.segT != null);
+
+    }
+
+    @Override
+    public String toString() {
+        return "Production{" + segS + " : " + segT + ", " + origin + '}';
+    }
+}
diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Program.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Program.java
new file mode 100644
index 000000000..dbbfe02eb
--- /dev/null
+++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Program.java
@@ -0,0 +1,732 @@
+/**
+ * This software is released under the University of Illinois/Research and Academic Use License. See
+ * the LICENSE file in the root folder for details. Copyright (c) 2016
+ *
+ * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign
+ * http://cogcomp.cs.illinois.edu/
+ */
+package edu.illinois.cs.cogcomp.transliteration;
+
+import edu.illinois.cs.cogcomp.core.datastructures.Pair;
+import edu.illinois.cs.cogcomp.core.datastructures.Triple;
+import edu.illinois.cs.cogcomp.core.io.LineIO;
+import edu.illinois.cs.cogcomp.utils.Dictionaries;
+import edu.illinois.cs.cogcomp.utils.InternDictionary;
+import edu.illinois.cs.cogcomp.utils.SparseDoubleVector;
+import edu.illinois.cs.cogcomp.utils.TopList;
+
+import java.io.FileNotFoundException;
+import java.text.Normalizer;
+import java.util.*;
+
+
+import net.sourceforge.pinyin4j.PinyinHelper;
+
+class Program {
+
+    public static void main(String[] args) throws FileNotFoundException {
+        //RussianDiscoveryTest(); return;
+        //ChineseDiscoveryTest(); return;
+        HebrewDiscoveryTest();
+
+
+    }
+
+    public static String StripAccent(String stIn) {
+        String strNFD = Normalizer.normalize(stIn, Normalizer.Form.NFD);
+        StringBuilder sb = new StringBuilder();
+        for (char ch : strNFD.toCharArray()) {
+            if (Character.getType(ch) != Character.NON_SPACING_MARK) {
+                sb.append(ch);
+            }
+        }
+        return sb.toString();
+
+    }
+
+
+    public static void HebrewDiscoveryTest() throws FileNotFoundException {
+        List<Pair<String, String>> wordList = NormalizeHebrew(GetTabDelimitedPairs("/path/to/res/res/Hebrew/evalwords.txt"));
+        List<Pair<String, String>> trainList = NormalizeHebrew(GetTabDelimitedPairs("/path/to/res/res/Hebrew/train_EnglishHebrew.txt"));
+        List<Pair<String, String>> trainList2 = NormalizeHebrew(GetTabDelimitedPairs("/path/to/WikiTransliteration/Aliases/heExamples.txt"));
+        //List<Pair<String, String>> trainList2 = RemoveVeryLong(NormalizeHebrew(GetTabDelimitedPairs(@"C:\Data\WikiTransliteration\Aliases\heExamples-Lax.txt")), 20);
+
+        List<Pair<String, String>> wordAndTrain = new ArrayList<>(wordList);
+        wordAndTrain.addAll(trainList);
+
+        HashMap<String, Boolean> usedExamples = new HashMap<>();
+        for (Pair<String, String> pair : wordAndTrain) {
+            usedExamples.put(pair.getFirst(), true);
+        }
+        //trainList2 = TruncateList(trainList2, 2000);
+        for (Pair<String, String> pair : trainList2) {
+            if (!usedExamples.containsKey(pair.getFirst()))
+                trainList.add(pair);
+ } + + //DiscoveryTestDual(RemoveDuplicates(GetListValues(wordList)), trainList, LiftPairList(wordList), 15, 15); + //TestXMLData(trainList, wordList, 15, 15); + + + List candidateList = GetListValues(wordList); + //wordList = GetRandomPartOfList(trainList, 50, 31); + candidateList.addAll(GetListValues(wordList)); + + DiscoveryEM(200, RemoveDuplicates(candidateList), trainList, LiftPairList(wordList), new CSPModel(40, 0, 0, 0.000000000000001, CSPModel.SegMode.None, false, CSPModel.SmoothMode.BySum, FallbackStrategy.NotDuringTraining, CSPModel.EMMode.Normal, false)); + //DiscoveryEM(200, RemoveDuplicates(candidateList), trainList, LiftPairList(wordList), new CSPModel(40, 0, 0, 0, FallbackStrategy.Standard)); + + //DiscoveryTestDual(RemoveDuplicates(candidateList), trainList, LiftPairList(wordList), 40, 40); + //DiscoveryTest(RemoveDuplicates(candidateList), trainList, LiftPairList(wordList), 40, 40); + } + + public static void ChineseDiscoveryTest() throws FileNotFoundException { + + //List> trainList = CharifyTargetWords(GetTabDelimitedPairs(@"C:\Users\jpaster2\Desktop\res\res\Chinese\chinese_full"),chMap); + List> trainList = UndotTargetWords(GetTabDelimitedPairs("/path/to/res/res/Chinese/chinese_full")); + List> wordList = GetRandomPartOfList(trainList, 700, 123); + + //StreamWriter writer = new StreamWriter(@"C:\Users\jpaster2\Desktop\res\res\Chinese\chinese_test_pairs.txt"); + //for (Pair pair in wordList) + // writer.WriteLine(pair.Key + "\t" + pair.Value); + + //writer.Close(); + + List candidates = GetListValues(wordList); + //wordList.RemoveRange(600, 100); + wordList = wordList.subList(0, 600); + + //DiscoveryTestDual(RemoveDuplicates(candidates), trainList, LiftPairList(wordList), 15, 15); + DiscoveryEM(200, RemoveDuplicates(candidates), trainList, LiftPairList(wordList), new CSPModel(40, 0, 0, 0.000000000000001, CSPModel.SegMode.Entropy, false, CSPModel.SmoothMode.BySource, FallbackStrategy.Standard, CSPModel.EMMode.Normal, false)); + //TestXMLData(trainList, wordList, 15, 15); + + + } + + public static void RussianDiscoveryTest() throws FileNotFoundException { + + List candidateList = LineIO.read("/path/to/res/res/Russian/RussianWords"); + for (int i = 0; i < candidateList.size(); i++) + candidateList.set(i, candidateList.get(i).toLowerCase()); + //candidateList.Clear(); + + //HashMap> evalList = GetAlexData(@"C:\Users\jpaster2\Desktop\res\res\Russian\evalpairs.txt");//@"C:\Users\jpaster2\Desktop\res\res\Russian\evalpairs.txt"); + HashMap> evalList = GetAlexData("/path/to/res/res/Russian/evalpairsShort.txt");//@"C:\Users\jpaster2\Desktop\res\res\Russian\evalpairs.txt"); + + //List> trainList = NormalizeHebrew(GetTabDelimitedPairs(@"C:\Users\jpaster2\Desktop\res\res\Hebrew\train_EnglishHebrew.txt")); + List> trainList = GetTabDelimitedPairs("/path/to/WikiTransliteration/Aliases/ruExamples.txt"); + //List> trainList2 = RemoveVeryLong(NormalizeHebrew(GetTabDelimitedPairs(@"C:\Data\WikiTransliteration\Aliases\heExamples-Lax.txt")), 20); + List> trainList2 = new ArrayList<>(trainList.size()); + + HashMap usedExamples = new HashMap<>(); + for (String s : evalList.keySet()) + usedExamples.put(s, true); + + //trainList2 = TruncateList(trainList2, 2000); + for (Pair pair : trainList) + if (!usedExamples.containsKey(pair.getFirst())) trainList2.add(pair); + + DiscoveryEM(200, RemoveDuplicates(GetWords(evalList)), trainList2, evalList, new CSPModel(40, 0, 0, 0.000000000000001, CSPModel.SegMode.Entropy, false, CSPModel.SmoothMode.BySource, FallbackStrategy.Standard, 
CSPModel.EMMode.Normal, false)); + //DiscoveryTestDual(RemoveDuplicates(candidateList), trainList2, evalList, 15, 15); + //DiscoveryTestDual(RemoveDuplicates(GetWords(evalList)), trainList2, evalList, 15, 15); + + } + + public static List GetWords(HashMap> dict) { + List result = new ArrayList<>(); + for (List list : dict.values()) + result.addAll(list); + return result; + } + + public static HashMap> LiftPairList(List> list) { + HashMap> result = new HashMap<>(list.size()); + for (Pair pair : list) { + ArrayList tmp = new ArrayList<>(); + tmp.add(pair.getSecond()); + result.put(pair.getFirst(), tmp); + } + + return result; + } + + private static List> UndotTargetWords(List> list) { + List> result = new ArrayList<>(list.size()); + for (Pair pair : list) { + result.add(new Pair<>(pair.getFirst(), pair.getSecond().replace(".", ""))); + } + + return result; + } + + + private static List RemoveDuplicates(List list) { + List result = new ArrayList<>(list.size()); + HashMap seen = new HashMap<>(); + for (String s : list) { + if (seen.containsKey(s)) continue; + seen.put(s, true); + result.add(s); + } + + return result; + } + + private static List> ConvertExamples(List> examples) { + List> fExamples = new ArrayList<>(examples.size()); + for (Pair pair : examples) + fExamples.add(new Triple<>(pair.getFirst(), pair.getSecond(), 1.0)); + + return fExamples; + } + + /** + * Calculates a probability table for P(String2 | String1) + * This normalizes by source counts for each production. + *

+ * FIXME: This is identical to... WikiTransliteration.NormalizeBySourceSubstring
+ *
+ * For example, counts of (a:x)=3 and (a:y)=1 yield P(x|a)=0.75 and P(y|a)=0.25.
+ *
+ * @param counts a table of production counts from training
+ * @return a table mapping each production (s, t) to P(t | s)
+ */
+    public static HashMap<Production, Double> PSecondGivenFirst(HashMap<Production, Double> counts) {
+        // counts of first words in productions.
+        HashMap<String, Double> totals1 = WikiTransliteration.GetAlignmentTotals1(counts);
+
+        HashMap<Production, Double> result = new HashMap<>(counts.size());
+        for (Production prod : counts.keySet()) // loop over all productions
+        {
+            double prodcount = counts.get(prod);
+            double sourcecounts = totals1.get(prod.getFirst()); // be careful of unboxing!
+            double value = sourcecounts == 0 ? 0 : (prodcount / sourcecounts);
+            result.put(prod, value);
+        }
+
+        return result;
+    }
+
+
+    private static HashMap<Production, Double> SumNormalize(HashMap<Production, Double> vector) {
+        HashMap<Production, Double> result = new HashMap<>(vector.size());
+        double sum = 0;
+        for (double value : vector.values()) {
+            sum += value;
+        }
+
+        for (Production key : vector.keySet()) {
+            Double value = vector.get(key);
+            result.put(key, value / sum);
+        }
+
+        return result;
+    }
+
+
+    /**
+     * Does this not need to have access to ngramsize? No. It gets all ngrams so it can backoff.
+     *
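+ * A hypothetical call, mirroring how SPModel.SetLanguageModel uses this method:
+ * <pre>{@code
+ * // n-gram counts (lengths 1..4, with boundary padding) over target-language words
+ * HashMap<String, Double> lm = Program.GetNgramCounts(targetWords, 4);
+ * }</pre>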

+ * By default, this includes padding. + * + * @param examples + * @param maxSubstringLength + * @return + */ + public static HashMap GetNgramCounts(List examples, int maxSubstringLength) { + return WikiTransliteration.GetNgramCounts(1, maxSubstringLength, examples, true); + } + + /** + * Given a wikidata file, this gets all the words in the foreign language for the language model. + * + * @param fname + * @return + */ + public static List getForeignWords(String fname) throws FileNotFoundException { + List lines = LineIO.read(fname); + List words = new ArrayList<>(); + + for (String line : lines) { + String[] parts = line.trim().split("\t"); + String foreign = parts[0]; + + String[] fsplit = foreign.split(" "); + for (String word : fsplit) { + words.add(word); + } + } + + return words; + } + + /** + * Given a wikidata file, this gets all the words in the foreign language for the language model. + * + * @return + */ + public static List getForeignWords(List examples) throws FileNotFoundException { + + List words = new ArrayList<>(); + + for (Example e : examples) { + words.add(e.getTransliteratedWord()); + } + + return words; + } + + + public static List> GetTabDelimitedPairs(String filename) throws FileNotFoundException { + List> result = new ArrayList<>(); + + for (String line : LineIO.read(filename)) { + String[] pair = line.trim().split("\t"); + if (pair.length != 2) continue; + result.add(new Pair<>(pair[0].trim().toLowerCase(), pair[1].trim().toLowerCase())); + } + + return result; + } + + + static HashMap> maxCache = new HashMap<>(); + + /** + * This returns a map of productions to counts. These are counts over the entire training corpus. These are all possible + * productions seen in training data. If a production does not show up in training, it will not be seen here. + *

+ * normalization parameter decides if it is normalized (typically not). + * + * @param maxSubstringLength1 + * @param maxSubstringLength2 + * @param examples + * @param probs + * @param weightingMode + * @param normalization + * @param getExampleCounts + * @return + */ + static HashMap MakeRawAlignmentTable(int maxSubstringLength1, int maxSubstringLength2, List> examples, HashMap probs, WeightingMode weightingMode, WikiTransliteration.NormalizationMode normalization, boolean getExampleCounts) { + InternDictionary internTable = new InternDictionary<>(); + HashMap counts = new HashMap<>(); + + List>> exampleCounts = (getExampleCounts ? new ArrayList>>(examples.size()) : null); + + int alignmentCount = 0; + for (Triple example : examples) { + String sourceWord = example.getFirst(); + String bestWord = example.getSecond(); // bestWord? Shouldn't it be target word? + if (sourceWord.length() * maxSubstringLength2 >= bestWord.length() && bestWord.length() * maxSubstringLength1 >= sourceWord.length()) { + alignmentCount++; + + HashMap wordCounts; + + if (weightingMode == WeightingMode.FindWeighted && probs != null) + wordCounts = WikiTransliteration.FindWeightedAlignments(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, probs, internTable, normalization); + //wordCounts = WikiTransliteration.FindWeightedAlignmentsAverage(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, probs, internTable, true, normalization); + else if (weightingMode == WeightingMode.CountWeighted) + wordCounts = WikiTransliteration.CountWeightedAlignments(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, probs, internTable, normalization, false); + else if (weightingMode == WeightingMode.MaxAlignment) { + + HashMap cached = new HashMap<>(); + Production p = new Production(sourceWord, bestWord); + if (maxCache.containsKey(p)) { + cached = maxCache.get(p); + } + + Dictionaries.AddTo(probs, cached, -1.); + + wordCounts = WikiTransliteration.CountMaxAlignments(sourceWord, bestWord, maxSubstringLength1, probs, internTable, false); + maxCache.put(new Production(sourceWord, bestWord), wordCounts); + + Dictionaries.AddTo(probs, cached, 1); + } else if (weightingMode == WeightingMode.MaxAlignmentWeighted) + wordCounts = WikiTransliteration.CountMaxAlignments(sourceWord, bestWord, maxSubstringLength1, probs, internTable, true); + else {//if (weightingMode == WeightingMode.None || weightingMode == WeightingMode.SuperficiallyWeighted) + // This executes if probs is null + wordCounts = WikiTransliteration.FindAlignments(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, internTable, normalization); + } + + if (weightingMode == WeightingMode.SuperficiallyWeighted && probs != null) { + wordCounts = SumNormalize(Dictionaries.MultiplyDouble(wordCounts, probs)); + } + + Dictionaries.AddTo(counts, wordCounts, example.getThird()); + + if (getExampleCounts) { + List> curExampleCounts = new ArrayList<>(wordCounts.size()); + for (Production key : wordCounts.keySet()) { + Double value = wordCounts.get(key); + curExampleCounts.add(new Pair<>(key, value)); + } + + exampleCounts.add(curExampleCounts); + } + } else if (getExampleCounts) { + exampleCounts.add(null); + } + } + + return counts; + } + + + public static SparseDoubleVector, String>> PSecondGivenFirst(SparseDoubleVector, String>> productionProbs) { + SparseDoubleVector, String>> result = new SparseDoubleVector<>(); + SparseDoubleVector> totals = new SparseDoubleVector<>(); + for (Pair, String> key : productionProbs.keySet()) { + Double 
value = productionProbs.get(key); + totals.put(key.getFirst(), totals.get(key.getFirst()) + value); + } + + for (Pair, String> key : productionProbs.keySet()) { + Double value = productionProbs.get(key); + result.put(key, value / totals.get(key.getFirst())); + } + + return result; + + } + + + /** + * Calculates a probability table for P(String2 | String1) + * + * @param counts + * @return + */ + public static SparseDoubleVector, String>> PSecondGivenFirstTriple(SparseDoubleVector, String>> counts) { + SparseDoubleVector> totals = new SparseDoubleVector<>(); + for (Pair, String> key : counts.keySet()) { + Double value = counts.get(key); + totals.put(key.getFirst(), totals.get(key.getFirst()) + value); + } + + SparseDoubleVector, String>> result = new SparseDoubleVector, String>>(counts.size()); + for (Pair, String> key : counts.keySet()) { + Double value = counts.get(key); + double total = totals.get(key.getFirst()); + result.put(key, total == 0 ? 0 : value / total); + } + + return result; + } + + /** + * Returns a random subset of a list; the list provided is modified to remove the selected items. + * + * @param wordList + * @param count + * @param seed + * @return + */ + public static List> GetRandomPartOfList(List> wordList, int count, int seed) { + Random r = new Random(seed); + + List> randomList = new ArrayList<>(count); + + for (int i = 0; i < count; i++) { + + int index = r.nextInt(wordList.size()); //r.Next(wordList.size()); + randomList.add(wordList.get(index)); + wordList.remove(index); + } + + return randomList; + } + + public static List GetListValues(List> wordList) { + List valueList = new ArrayList<>(wordList.size()); + for (Pair pair : wordList) + valueList.add(pair.getSecond()); + + return valueList; + } + + + static double Choose(double n, double k) { + double result = 1; + + for (double i = Math.max(k, n - k) + 1; i <= n; ++i) + result *= i; + + for (double i = 2; i <= Math.min(k, n - k); ++i) + result /= i; + + return result; + } + + public static double[][] SegmentationCounts(int maxLength) { + double[][] result = new double[maxLength][]; + for (int i = 0; i < maxLength; i++) { + result[i] = new double[i + 1]; + for (int j = 0; j <= i; j++) + result[i][j] = Choose(i, j); + } + + return result; + } + + public static double[][] SegSums(int maxLength) { + double[][] segmentationCounts = SegmentationCounts(maxLength); + double[][] result = new double[maxLength][]; + for (int i = 0; i < maxLength; i++) { + result[i] = new double[maxLength]; + for (int j = 0; j < maxLength; j++) { + int minIJ = Math.min(i, j); + for (int k = 0; k <= minIJ; k++) + result[i][j] += segmentationCounts[i][k] * segmentationCounts[j][k];// *Math.Pow(0.5, k + 1); + } + } + + return result; + } + + /** + * Number of possible segmentations. 
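+ * Entry [i][j] is the sum over k of C(i,k)*C(j,k), where C is the binomial coefficient: the
+ * number of ways to cut two strings of i+1 and j+1 characters into the same number of segments.
+ * Probability() divides by segSums[len1-1][len2-1] to normalize the summed alignment score.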
+     */
+    public static double[][] segSums = SegSums(40);
+
+    public static void DiscoveryEM(int iterations, List<String> candidateWords, List<Pair<String, String>> trainingPairs, HashMap<String, List<String>> testingPairs, TransliterationModel model) {
+        List<Triple<String, String, Double>> trainingTriples = ConvertExamples(trainingPairs);
+
+        for (int i = 0; i < iterations; i++) {
+            System.out.println("Iteration #" + i);
+
+            long startTime = System.nanoTime();
+            System.out.print("Training...");
+            model = model.LearnModel(trainingTriples);
+            long endTime = System.nanoTime();
+            // elapsed time is end minus start, converted from nanoseconds to seconds
+            System.out.println("Finished in " + (endTime - startTime) / (1000000 * 1000) + " seconds.");
+
+            DiscoveryEvaluation(testingPairs, candidateWords, model);
+        }
+
+        System.out.println("Finished.");
+    }
+
+
+    public static String NormalizeHebrew(String word) {
+        word = word.replace('ן', 'נ');
+        word = word.replace('ך', 'כ');
+        word = word.replace('ץ', 'צ');
+        word = word.replace('ם', 'מ');
+        word = word.replace('ף', 'פ');
+
+        return word;
+    }
+
+    public static List<Pair<String, String>> NormalizeHebrew(List<Pair<String, String>> pairs) {
+        List<Pair<String, String>> result = new ArrayList<>(pairs.size());
+
+        for (int i = 0; i < pairs.size(); i++)
+            result.add(new Pair<>(pairs.get(i).getFirst(), NormalizeHebrew(pairs.get(i).getSecond())));
+
+        return result;
+    }
+
+    /**
+     * This is still broken...
+     *
+     * @param path
+     * @return
+     * @throws FileNotFoundException
+     */
+    static HashMap<String, List<String>> GetAlexData(String path) throws FileNotFoundException {
+
+        HashMap<String, List<String>> result = new HashMap<>();
+
+        // TODO: this is all broken.
+//        ArrayList<String> data = LineIO.read(path);
+//
+//        for (String line : data)
+//        {
+//            if (line.length() == 0) continue;
+//
+//            Match match = Regex.Match(line, "(?<eng>\\w+)\t(?<rroot>\\w+)(?: {(?:-(?<rsuf>\\w*?)(?:(?:, )|}))+)?", RegexOptions.Compiled);
+//
+//            String russianRoot = match.Groups["rroot"].Value;
+//            if (russianRoot.length() == 0)
+//                System.out.println("Parse error");
+//
+//            List<String> russianList = new ArrayList<>();
+//
+//            //if (match.Groups["rsuf"].Captures.Count == 0)
+//            russianList.Add(russianRoot.toLower()); //root only
+//            //else
+//            for (Capture capture : match.Groups["rsuf"].Captures)
+//                russianList.add((russianRoot + capture.Value).ToLower());
+//
+//            result[match.Groups["eng"].Value.ToLower()] = russianList;
+//
+//        }
+
+        return result;
+    }
+
+
+    public static HashMap<Production, Double> PruneProbs(int topK, HashMap<Production, Double> probs) {
+        HashMap<String, List<Pair<String, Double>>> lists = new HashMap<>();
+        for (Production key : probs.keySet()) {
+            Double value = probs.get(key);
+
+            if (!lists.containsKey(key.getFirst())) {
+                lists.put(key.getFirst(), new ArrayList<Pair<String, Double>>());
+            }
+
+            lists.get(key.getFirst()).add(new Pair<>(key.getSecond(), value));
+        }
+
+        HashMap<Production, Double> result = new HashMap<>();
+        for (String key : lists.keySet()) {
+            List<Pair<String, Double>> value = lists.get(key);
+            // sort each source segment's candidate targets best-first (descending probability)
+            Collections.sort(value, new Comparator<Pair<String, Double>>() {
+                @Override
+                public int compare(Pair<String, Double> o1, Pair<String, Double> o2) {
+                    return Double.compare(o2.getSecond(), o1.getSecond());
+                }
+            });
+            int toAdd = Math.min(topK, value.size());
+            for (int i = 0; i < toAdd; i++) {
+                result.put(new Production(key, value.get(i).getFirst()), value.get(i).getSecond());
+            }
+        }
+
+        return result;
+    }
+
+    private static void DiscoveryEvaluation(HashMap<String, List<String>> testingPairs, List<String> candidates, TransliterationModel model) {
+        int correct = 0;
+        //int contained = 0;
+        double mrr = 0;
+        int misses = 0;
+
+        for (String key : testingPairs.keySet()) {
+            List<String> value = testingPairs.get(key);
+
+            //double[] scores = new double[candidates.size()];
+            // Score every candidate against the source word, then rank best-first.
+            final HashMap<String, Double> scores = new HashMap<>(candidates.size());
+            for (String cand : candidates)
+                scores.put(cand, model.GetProbability(key, cand));
+
+            // sort the words according to the scores, descending (note: do not sort the
+            // caller's candidate list in place)
+            List<String> ranked = new ArrayList<>(candidates);
+            Collections.sort(ranked, new Comparator<String>() {
+                public int compare(String left, String right) {
+                    return Double.compare(scores.get(right), scores.get(left));
+                }
+            });
+
+            // 1-based rank of the highest-ranked correct answer; defaults to the list size
+            // when no correct answer is present.
+            int index = ranked.size();
+            for (int i = 0; i < ranked.size(); i++)
+                if (value.contains(ranked.get(i))) {
+                    index = i + 1;
+                    break;
+                }
+
+            if (index == 1)
+                correct++;
+            else
+                misses++;
+            mrr += 1.0 / index;
+        }
+
+        mrr /= testingPairs.size();
+
+        System.out.println(testingPairs.size() + " pairs tested in total; " + candidates.size() + " candidates.");
+        //System.out.println(contained + " predictions contained (" + (((double)contained) / testingPairs.Count) + ")");
+        System.out.println(correct + " predictions exactly correct (" + (((double) correct) / testingPairs.size()) + ")");
+        System.out.println("MRR: " + mrr);
+    }
+
+    public static SparseDoubleVector<Production> InitializeWithRomanization(SparseDoubleVector<Production> probs, List<Triple<String, String, Double>> trainingTriples, List<MultiExample> testing) {
+
+//        List<String> hebtable;
+//        try {
+//            hebtable = LineIO.read("hebrewromanization.txt");
+//        } catch (FileNotFoundException e) {
+//            return probs;
+//        }
+//
+//        for(String line : hebtable){
+//            String[] sline = line.split(" ");
+//            String heb = sline[0];
+//            String eng = sline[1];
+//
+//            probs.put(new Production(eng, heb), 1.0);
+//        }
+
+        // get all values from training.
+        for (Triple<String, String, Double> t : trainingTriples) {
+            String chinese = t.getSecond();
+            for (char c : chinese.toCharArray()) {
+                // CHINESE
+                String[] res = PinyinHelper.toHanyuPinyinStringArray(c);
+                if (res == null) continue; // pinyin4j returns null for non-Chinese characters
+                for (String s : res) {
+                    // FIXME: strip number from s?
+                    String ss = s.substring(0, s.length() - 1);
+                    probs.put(new Production(ss, c + ""), 1.);
+                }
+            }
+        }
+
+        // get all values from testing also
+        for (MultiExample t : testing) {
+            List<String> chineseWords = t.getTransliteratedWords();
+            for (String chinese : chineseWords) {
+                for (char c : chinese.toCharArray()) {
+                    // CHINESE
+                    String[] res = PinyinHelper.toHanyuPinyinStringArray(c);
+                    if (res == null) continue; // pinyin4j returns null for non-Chinese characters
+                    for (String s : res) {
+                        // FIXME: strip number from s?
+                        String sss = s.substring(0, s.length() - 1);
+                        probs.put(new Production(sss, c + ""), 1.);
+                    }
+                }
+            }
+        }
+
+
+        return probs;
+    }
+
+
+    /**
+     * Convert each production into a set of productions with different origins
+     *
+     * @param probs
+     * @param numOrigins
+     * @return
+     */
+    public static SparseDoubleVector<Production> SplitIntoOrigins(SparseDoubleVector<Production> probs, int numOrigins) {
+
+        SparseDoubleVector<Production> newprobs = new SparseDoubleVector<>();
+
+        for (Production p : probs.keySet()) {
+            double prob = probs.get(p);
+            for (int i = 0; i < numOrigins; i++) {
+                Production po = new Production(p.segS, p.segT, i);
+                newprobs.put(po, prob / numOrigins);
+            }
+        }
+
+        return newprobs;
+    }
+
+
+    public enum WeightingMode {
+        None, FindWeighted, SuperficiallyWeighted, CountWeighted, MaxAlignment, MaxAlignmentWeighted
+    }
+
+}
diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Runner.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Runner.java
new file mode 100644
index 000000000..5b93b2d39
--- /dev/null
+++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/Runner.java
@@ -0,0 +1,763 @@
+/**
+ * This software is released under the University of Illinois/Research and Academic Use License. See
+ * the LICENSE file in the root folder for details. Copyright (c) 2016
+ *
+ * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign
+ * http://cogcomp.cs.illinois.edu/
+ */
+package edu.illinois.cs.cogcomp.transliteration;
+
+import edu.illinois.cs.cogcomp.core.algorithms.LevensteinDistance;
+import edu.illinois.cs.cogcomp.core.algorithms.ProducerConsumer;
+import edu.illinois.cs.cogcomp.core.datastructures.Pair;
+import edu.illinois.cs.cogcomp.core.io.LineIO;
+import edu.illinois.cs.cogcomp.utils.Dictionaries;
+import edu.illinois.cs.cogcomp.utils.SparseDoubleVector;
+import edu.illinois.cs.cogcomp.utils.TopList;
+import edu.illinois.cs.cogcomp.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.*;
+
+
+class Runner {
+
+    static String dataPath = "Data/hebrewEnglishAlignment/";
+    static String wikidata = "/shared/corpora/transliteration/wikidata/";
+    static String wikidataurom = "/shared/corpora/transliteration/wikidata.urom/";
+    static String wikidataextra = "/shared/corpora/transliteration/wikidata-extra/";
+    static String NEWS = "/shared/corpora/transliteration/NEWS2015/";
+    static String tl = "/shared/corpora/transliteration/";
+    static String irvinedata = tl + "from_anne_irvine/";
+
+    // set these later on
+    public static int NUMTRAIN = -1;
+    public static int NUMTEST = -1;
+
+    public static double TRAINRATIO = 0.75;
+
+
+    private static Logger logger = LoggerFactory.getLogger(Runner.class);
+
+    public static void main(String[] args) throws Exception {
+
+        String trainfile = args[0];
+        String testfile = args[1];
+
+        String trainlang = "Hindi";
+        String testlang = "Hindi";
+        String probFile = "nonsenseword";
+
+        String method = "wikidata";
+
+        if(method.equals("wikidata")) {
+            TrainAndTest(trainfile, testfile);
+        }else if(method.equals("wikidata-pretrain")){
+            LoadAndTest(probFile, testfile);
+        }else if(method.equals("NEWS")){
+            String langpair = "EnBa";
+
+            TrainAndTestNEWS(langpair);
+        }else if(method.equals("CCB")) {
+            List data = Utils.readCCBData("en", "ja");
+            System.out.println(data.size());
+        }else if(method.equals("compare")){
+            compare(trainfile, testfile);
+        }else if(method.equals("test")) {
+            test();
+        }else if(method.equals("makedata")){
+            makedata(trainfile, testfile);
+        }else if(method.equals("makeprobs")) {
+            trainfile = String.format("wikidata.%s-%s", trainlang, testlang);
+            makeprobs(trainfile, trainlang, testlang);
+        }else if(method.equals("experiments")){
+            experiments();
+        }else{
+            logger.error("Should never get here! Try a new method. It was: " + method);
+        }
+
+    }
+
+    /**
+     * Run this method to get all results for ranking.
+     * @throws Exception
+     */
+    private static void experiments() throws Exception {
+        String[] arabic_names = {"Arabic", "Egyptian_Arabic", "Mazandarani", "Pashto", "Persian", "Western_Punjabi"};
+        String[] devanagari_names = {"Hindi", "Marathi", "Nepali", "Sanskrit"};
+        String[] cyrillic_names = {"Bashkir", "Bulgarian", "Chechen", "Kirghiz", "Macedonian", "Russian", "Ukrainian"};
+
+        List<String> cyrillicresults = new ArrayList<>();
+        NUMTRAIN = 381;
+        NUMTEST = 482;
+        for(String name : cyrillic_names){
+            logger.debug("Working on " + name);
+            String trainfile = wikidata + String.format("wikidata.%s", name);
+            String testfile = wikidata + String.format("wikidata.%s", "Chuvash");
+
+            cyrillicresults.add(name + TrainAndTest(trainfile, testfile));
+        }
+        LineIO.write("cyrillicresults.txt", cyrillicresults);
+
+    }
+
+
+    /**
+     This trains a model from a file and writes the productions
+     to file. This is intended primarily as a way to measure WAVE.
+     Use this in tandem with makedata().
+     */
+    private static void makeprobs(String trainfile, String trainlang, String testlang) throws IOException{
+
+        List<Example> training = Utils.readWikiData("gen-data/" + trainfile);
+
+        SPModel model = new SPModel(training);
+
+        model.Train(5);
+
+        model.WriteProbs("models/probs-" + trainlang + "-" + testlang + ".txt");
+
+    }
+
+    /**
+     * Given two language names, this will create a file of pairs between these languages by
+     * finding pairs in each language with common English sources.
+     *
+     * The output of this will be used in makeprobs to get WAVE scores.
+     *
+     * @param trainfile
+     * @param testfile
+     * @throws IOException
+     */
+    private static void makedata(String trainfile, String testfile) throws IOException {
+        List<Example> training = Utils.readWikiData(trainfile);
+        List<Example> testing = Utils.readWikiData(testfile);
+
+        String langA = trainfile.split("\\.")[1];
+        String langB = testfile.split("\\.")[1];
+
+        // this creates examples that map from training tgt lang to testing tgt lang
+        List<Example> a2b = new ArrayList<>();
+
+        HashMap<String, HashSet<Example>> engToEx = new HashMap<>();
+        for(Example e : training){
+            String eng = e.sourceWord;
+
+            HashSet<Example> prods;
+            if(engToEx.containsKey(eng)){
+                prods = engToEx.get(eng);
+            }else {
+                prods = new HashSet<>();
+            }
+            prods.add(e);
+            engToEx.put(eng, prods);
+        }
+
+        logger.debug("Done with reading " + trainfile);
+
+        for(Example e : testing){
+            String eng = e.sourceWord;
+            if(engToEx.containsKey(eng)){
+                HashSet<Example> examples = engToEx.get(eng);
+
+                for(Example e2a : examples) {
+                    // eng of e2a is same as eng of e
+                    a2b.add(new Example(e2a.getTransliteratedWord(), e.getTransliteratedWord()));
+                }
+
+            }
+        }
+
+        logger.debug("Done with reading " + testfile);
+
+
+        HashSet<String> outlines = new HashSet<>();
+
+        for(Example e : a2b){
+            // wikidata file ordering is tgt src.
+            outlines.add(e.getTransliteratedWord() + "\t" + e.sourceWord);
+        }
+
+        List<String> listlines = new ArrayList<>(outlines);
+        listlines.add(0, "# " + langB + "\t" + langA + "\n");
+
+        LineIO.write("gen-data/wikidata."
+ langA + "-" + langB, listlines); + + } + + static void compare(String trainfile, String testfile) throws FileNotFoundException { + List training = Utils.readWikiData(trainfile); + List testing = Utils.readWikiData(testfile); + + // get transliteration pairs between these two. + HashMap english2train = new HashMap<>(); + for(Example e : training){ + if(english2train.containsKey(e.sourceWord)){ + // probably not a problem. Because the readWikiData splits the names into first and last, + // it is likely that we will see the same first name many times. Assume it is the same + // each time (weak!) + } + english2train.put(e.sourceWord,e.getTransliteratedWord()); + } + + int num = 0; + double sum = 0; + for(Example e : testing){ + if(english2train.containsKey(e.sourceWord)){ + String train = english2train.get(e.sourceWord); + // there is only one of these, so it doesn't matter if it is a list. + List test = e.getTransliteratedWords(); + + // get edit distance between these. + double F1 = Utils.GetFuzzyF1(train, test); + sum += F1; + num++; + } + } + System.out.println("Num pairs: " + num); + System.out.println("Avg F1: " + sum / num); + + + } + + + /** + * A function for testing bridge languages. This creates a transitive model. + * @throws IOException + */ + static void test() throws IOException { + // open probs.Nepali + // open probs.Arabic + SPModel t = new SPModel("probs-Nepali.txt"); + SPModel a = new SPModel("probs-Western_Punjabi.txt"); + + SparseDoubleVector tprobs = t.getProbs(); + SparseDoubleVector aprobs = a.getProbs(); + + // this creates productions that map from Nepali to arabic. + SparseDoubleVector t2a = new SparseDoubleVector<>(); + + // this maps from English segment to Nepali segment set. This set should add to 1 + + HashMap> firstToProd = new HashMap<>(); + for(Production p : tprobs.keySet()){ + String eng = p.getFirst(); + // don't include the tiny ones... + if(tprobs.get(p) < 0.1){ + continue; + } + HashSet prods; + if(firstToProd.containsKey(eng)){ + prods = firstToProd.get(eng); + }else { + prods = new HashSet<>(); + } + prods.add(p); + firstToProd.put(eng, prods); + } + + logger.debug("Done with reading nepali..."); + + for(Production p : aprobs.keySet()){ + String eng = p.getFirst(); + if(firstToProd.containsKey(eng)){ + HashSet telugu_prods = firstToProd.get(eng); + + for(Production tel : telugu_prods) { + // eng of tel is same as eng of p + + // maybe should be log probs. + double score = tprobs.get(tel) * aprobs.get(p); + + t2a.put(new Production(tel.getSecond(), p.getSecond()), score); + } + + } + } + + logger.debug("Done with reading arabic..."); + + + double threshold = 0.; + ArrayList outlines = new ArrayList<>(); + for(Production p : t2a.keySet()){ + if(t2a.get(p) > threshold) { + outlines.add(p.getFirst() + "\t" + p.getSecond() + "\t" + t2a.get(p)); + } + } + LineIO.write("probs-nepali-to-arabic.txt", outlines); + + } + + + + + /** + * Given a model and set of testing examples, this will get scores using the generation method. + * @param model this needs to be a trained model. + * @param testing a set of examples. 
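+ * @param lang used only to name the per-language output file written under output/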
+ * @return an 3-element double array of scores with elements MRR,ACC,F1 + * @throws Exception + */ + public static double[] TestGenerate(SPModel model, List testing, String lang) throws Exception { + double correctmrr = 0; + double correctacc = 0; + double totalf1 = 0; + + List outlines = new ArrayList<>(); + + //model.setMaxCandidates(30); + + int i = 0; + for (MultiExample example : testing) { + if(i%500 == 0) { + logger.debug("on example " + i + " out of " + testing.size()); + //logger.debug("USING THE CREATED MODEL TO GET INTO URDU."); + } + i++; + + outlines.add("SourceWord: " + example.sourceWord + ""); + for(String tw : example.getTransliteratedWords()){ + outlines.add("TransliteratedWords: " + tw); + } + + TopList prediction = model.Generate(example.sourceWord); + + for(Pair cand : prediction){ + if(example.getTransliteratedWords().contains(cand.getSecond())){ + outlines.add("**" + cand.getSecond() + ", " + cand.getFirst() + "**"); + }else{ + outlines.add("" + cand.getSecond() + ", " + cand.getFirst() + ""); + } + } + outlines.add("\n"); + + int bestindex = -1; + + double F1 = 0; + if(prediction.size() == 0){ + //logger.error("No cands for this word: " + example.sourceWord); + }else { + F1 = Utils.GetFuzzyF1(prediction.getFirst().getSecond(), example.getTransliteratedWords()); + } + totalf1 += F1; + + for(String target : example.getTransliteratedWords()){ + int index = prediction.indexOf(target); + if(bestindex == -1 || index < bestindex){ + bestindex = index; + } + } + + if (bestindex >= 0) { + correctmrr += 1.0 / (bestindex + 1); + if(bestindex == 0){ + correctacc += 1.0; + } + } + } + + LineIO.write("output/out-gen-"+ lang +".txt", outlines); + + double mrr = correctmrr / (double)testing.size(); + double acc = correctacc / (double)testing.size(); + double f1 = totalf1 / (double)testing.size(); + + double[] res = new double[3]; + res[0] = mrr; + res[1] = acc; + res[2] = f1; + + return res; + } + + /** + * Given a model and set of testing examples, this will get scores using the generation method. + * @param model this needs to be a trained model. + * @param testing a set of examples. + * @return an 3-element double array of scores with elements MRR,ACC,F1 + * @throws Exception + */ + public static double[] TestGenerateChain(SPModel model, List testing, String lang) throws Exception { + double correctmrr = 0; + double correctacc = 0; + double totalf1 = 0; + + List outlines = new ArrayList<>(); + + //String id = "Any-Arabic; NFD"; + //Transliterator t = Transliterator.getInstance(id); + + logger.warn("CREATING A SECOND STAGE MODEL RIGHT HERE."); + boolean fix = false; // don't try to fix the data... edit distance is weird in 2 foreign langs. + List training = Utils.readWikiData("gen-data/wikidata.Western_Punjabi-Urdu", fix); + logger.debug("Size of intermediate model: " + training.size()); + SPModel stage2model = new SPModel(training); + stage2model.setMaxCandidates(5); + stage2model.Train(5); + + + // FIXME: CAREFUL HERE!!! 
+ //logger.warn("SETTING MAXCANDS TO JUST 5"); + model.setMaxCandidates(5); + + int i = 0; + for (MultiExample example : testing) { + if(i%500 == 0) { + logger.debug("on example " + i + " out of " + testing.size()); + //logger.debug("USING THE CREATED MODEL TO GET INTO URDU."); + } + i++; + + outlines.add("SourceWord: " + example.sourceWord + ""); + for(String tw : example.getTransliteratedWords()){ + outlines.add("TransliteratedWords: " + tw); + } + + TopList prediction = model.Generate(example.sourceWord); + + // This block is for the second stage in the pipeline. + TopList scriptpreds = new TopList<>(25); + // there will be 5 of these + for(Pair cand : prediction){ + + // there will be 5 of these + TopList chuvashcands = stage2model.Generate(cand.getSecond()); + + for(Pair chaincand : chuvashcands){ + scriptpreds.add(cand.getFirst() * chaincand.getFirst(), chaincand.getSecond()); + } + + } + prediction = scriptpreds; + + for(Pair cand : prediction){ + if(example.getTransliteratedWords().contains(cand.getSecond())){ + outlines.add("**" + cand.getSecond() + ", " + cand.getFirst() + "**"); + }else{ + outlines.add("" + cand.getSecond() + ", " + cand.getFirst() + ""); + } + } + outlines.add("\n"); + + int bestindex = -1; + + double F1 = 0; + if(prediction.size() == 0){ + //logger.error("No cands for this word: " + example.sourceWord); + }else { + F1 = Utils.GetFuzzyF1(prediction.getFirst().getSecond(), example.getTransliteratedWords()); + } + totalf1 += F1; + + for(String target : example.getTransliteratedWords()){ + int index = prediction.indexOf(target); + if(bestindex == -1 || index < bestindex){ + bestindex = index; + } + } + + if (bestindex >= 0) { + correctmrr += 1.0 / (bestindex + 1); + if(bestindex == 0){ + correctacc += 1.0; + } + } + } + + LineIO.write("output/out-gen-"+ lang +".txt", outlines); + + double mrr = correctmrr / (double)testing.size(); + double acc = correctacc / (double)testing.size(); + double f1 = totalf1 / (double)testing.size(); + + double[] res = new double[3]; + res[0] = mrr; + res[1] = acc; + res[2] = f1; + + return res; + } + + + + public static Pair TestDiscovery(SPModel model, List testing) throws IOException { + double correctmrr = 0; + double correctacc = 0; + + List possibilities = new ArrayList<>(); + for(Example e : testing){ + possibilities.add(e.getTransliteratedWord()); + } + + List outlines = new ArrayList<>(); + + for (Example example : testing) { + + int topK = 30; + TopList ll = new TopList<>(topK); + for(String target : possibilities){ + double prob = model.Probability(example.sourceWord, target); + ll.add(prob, target); + } + + outlines.add(example.sourceWord); + for(Pair p : ll){ + String s = p.getSecond(); + + if(s.equals(example.getTransliteratedWord())){ + s = "**" + s + "**"; + } + + outlines.add(s); + } + outlines.add(""); + + int index = ll.indexOf(example.getTransliteratedWord()); + if (index >= 0) { + correctmrr += 1.0 / (index + 1); + if(index == 0){ + correctacc += 1.0; + } + } + } + + LineIO.write("output/out-disc.txt", outlines); + + double mrr = correctmrr / (double)testing.size(); + double acc = correctacc / (double)testing.size(); + + return new Pair<>(mrr, acc); + } + + /** + * This loads a prob file (having been generated from a previous testing run) and tests + * on the test file. 
+ * @param probFile + * @param testfile + * @throws Exception + */ + public static void LoadAndTest(String probFile, String testfile) throws Exception { + + List testing = Utils.readWikiData(testfile); + java.util.Collections.shuffle(testing); + + double avgmrr= 0; + double avgacc = 0; + double avgf1 = 0; + + SPModel model = new SPModel(probFile); + + logger.info("Testing."); + double[] res = TestGenerate(model, testing, "unknown"); + double mrr = res[0]; + double acc = res[1]; + double f1 = res[2]; + avgmrr += mrr; + avgacc += acc; + avgf1 += f1; + model.WriteProbs("models/probs-" + testfile.split("\\.")[1] +".txt"); + + //System.out.println("============="); + //System.out.println("AVGMRR=" + avgmrr); + //System.out.println("AVGACC=" + avgacc); + //System.out.println("AVGF1 =" + avgf1); + } + + + public static String TrainAndTest(String trainfile, String testfile) throws Exception { + + List subtraining; + List subtesting; + + List training = Utils.readWikiData(trainfile); + logger.info(String.format("Loaded %s training examples.", training.size())); + + List langstrings = Program.getForeignWords(training); + + if(trainfile.equals(testfile)){ + logger.info("Train and test are the same... using just train."); + // then don't need to load testing. + int trainnum = (int)Math.round(TRAINRATIO*training.size()); + subtraining = training.subList(0,trainnum); + subtesting = training.subList(trainnum,training.size()-1); + }else{ + logger.debug("Train and test are different."); + // OK, load testing. + List testing = Utils.readWikiData(testfile); + logger.info(String.format("Loaded %s testing examples", testing.size())); + + // now make sure there is no overlap between test and train. + // only keep those training examples that are not also in test. + subtraining = new ArrayList<>(); + for(Example e : training){ + if(!testing.contains(e)){ + subtraining.add(e); + } + } + + logger.info("After filtering, num training is: " + subtraining.size()); + + + subtesting = testing; + } + + logger.info("Actual Training: " + subtraining.size()); + logger.info("Actual Testing: " + subtesting.size()); + + + // params + int emiterations = 5; + boolean rom = false; // use romanization or not. 
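+        // The overlap filter above is O(|train| * |test|) because of List.contains; a set-based
+        // variant (a sketch, relying on the equals/hashCode that MultiExample defines) does the
+        // same job in roughly linear time:
+        //     HashSet<Example> testSet = new HashSet<>(testing);
+        //     for (Example e : training) if (!testSet.contains(e)) subtraining.add(e);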
+ + double avgmrr= 0; + double avgacc = 0; + double avgf1 = 0; + int num = 1; + + for (int i = 0; i < num; i++) { + + java.util.Collections.shuffle(subtraining); + + SPModel model = new SPModel(subtraining); + + + //model.setUseNPLM(true); + //model.setNPLMfile("lm/newar/nplm-new.txt"); + + //model.SetLanguageModel(langstrings); +// model.setNgramSize(2); + + model.setMaxCandidates(25); + + model.Train(emiterations, rom, subtesting); + +// Pair p = TestDiscovery(model, testing); +// double mrr = p.getFirst(); +// double acc = p.getSecond(); + logger.info("Testing."); + String[] pathsplit = trainfile.split("\\."); + String trainlang = pathsplit[pathsplit.length-1]; // get the last element, filename should be wikidata.Lang + + // This is for testing + + double[] res = TestGenerate(model, subtesting, trainlang); + double mrr = res[0]; + double acc = res[1]; + double f1 = res[2]; + avgmrr += mrr; + avgacc += acc; + avgf1 += f1; + System.out.println(subtraining.size() + "," + subtesting.size() + "," + mrr + "," + acc + "," + f1); + model.WriteProbs("models/probs-" + trainlang +".txt"); + } + + //System.out.println("============="); + //System.out.println("AVGMRR=" + avgmrr / num); + //System.out.println("AVGACC=" + avgacc / num); + //System.out.println("AVGF1 =" + avgf1 / num); + + //System.out.println("& " + avgmrr / num + " & " + avgacc / num + " & " + avgf1 / num + " \\\\"); + return " & " + avgmrr / num + " & " + avgacc / num + " & " + avgf1 / num + " \\\\"; + + } + + + + + + /** + * This is for training and testing of NEWS data. + + * @throws Exception + */ + public static void TrainAndTestNEWS(String langpair) throws Exception { + + // given a language pair (such as EnHi), we find the files that match it. Train is + // always for training, dev always for testing. + + String[] folders = {"NEWS2015_I2R","NEWS2015_MSRI","NEWS2015_NECTEC","NEWS2015_RMIT"}; + + String trainfile = ""; + String testfile = ""; + + for(String folder : folders){ + File filefolder = new File(NEWS + folder); + File[] files = filefolder.listFiles(); + for(File f : files){ + System.out.println(f.getName()); + String name = f.getName(); + if(name.contains("NEWS15_dev_" + langpair)){ + testfile = f.getAbsolutePath(); + }else if(name.contains("NEWS15_train_" + langpair)){ + trainfile = f.getAbsolutePath(); + } + } + } + + logger.debug("Using train: {}", trainfile); + logger.debug("Using test: {}", testfile); + + List trainingMulti = Utils.readNEWSData(trainfile); + + // convert the MultiExamples into single examples. + List training = Utils.convertMulti(trainingMulti); + + //logger.debug("USING A SHORTENED NUMBER OF TRAINING EXAMPLES!"); + //training = training.subList(0,700); + + List testing = Utils.readNEWSData(testfile); + + System.out.println("Training examples: " + training.size()); + System.out.println("Testing examples: " + testing.size()); + + double avgmrr= 0; + double avgacc = 0; + double avgf1 = 0; + int num = 1; + + for (int i = 0; i < num; i++) { + + //java.util.Collections.shuffle(training); + //java.util.Collections.shuffle(testing); + + SPModel model = new SPModel(training); + //List langstrings = Program.getForeignWords(wikidata + "wikidata.Hebrew"); + //List langstrings = Program.getForeignWords(training); + //List langstrings = LineIO.read("Data/heWords.txt"); + //model.SetLanguageModel(langstrings); + //model.SetLanguageModel(Utils.readSRILM("lm/lm-he.txt")); + //model.setUseNPLM(false); + + // This is set by the shared task. 
+ model.setMaxCandidates(50); + model.setNgramSize(3); + + int emiterations = 5; + + logger.info("Training with " + emiterations + " iterations."); + model.Train(emiterations); + + logger.info("Testing."); + double[] res = TestGenerate(model, testing, langpair); + double mrr = res[0]; + double acc = res[1]; + double f1 = res[2]; + avgmrr += mrr; + avgacc += acc; + avgf1 += f1; + model.WriteProbs("probs.txt", 0.1); + } + + System.out.println("============="); + System.out.println("AVGMRR=" + avgmrr / num); + System.out.println("AVGACC=" + avgacc / num); + System.out.println("AVGF1 =" + avgf1 / num); + } + +} + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/SPModel.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/SPModel.java new file mode 100644 index 000000000..fdec89e95 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/SPModel.java @@ -0,0 +1,373 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import edu.illinois.cs.cogcomp.core.datastructures.Pair; +import edu.illinois.cs.cogcomp.core.datastructures.Triple; +import edu.illinois.cs.cogcomp.core.io.LineIO; +import edu.illinois.cs.cogcomp.utils.SparseDoubleVector; +import edu.illinois.cs.cogcomp.utils.TopList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +import java.io.*; +import java.util.*; + +/** + *

Segmentation-Production Model for generating and discovering transliterations.
+ * Generation is the process of creating the transliteration of a word in a target language given the word in the source language.
+ * Discovery is the process of identifying a transliteration from a list of candidate words in the target language (this
+ * is facilitated here by obtaining the probability P(T|S)).
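+ * A minimal usage sketch (the training data and strings below are invented placeholders):
+ * <pre>{@code
+ * List<Example> training = ...;            // e.g. pairs like ("london", "лондон")
+ * SPModel model = new SPModel(training);
+ * model.Train(5);                          // 5 EM iterations
+ * TopList<Double, String> cands = model.Generate("london");   // generation
+ * double p = model.Probability("london", "лондон");           // discovery score P(T|S)
+ * }</pre>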

+ * + *

+ * One useful idea (used in our paper) is that you can find Sqrt(P(T|S)*P(S|T)) rather than just P(T|S) alone, since + * you don't necessarily know a priori in what direction the word was originally translated (e.g. from English to Russian or Russian to English?). + * In our discovery experiments this geometric mean performed substantially better, although of course your results may vary. + *
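+ * A sketch of that combination, assuming one SPModel trained in each direction:
+ * <pre>{@code
+ * double combined = Math.sqrt(srcToTgt.Probability(src, tgt) * tgtToSrc.Probability(tgt, src));
+ * }</pre>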

+ */ +public class SPModel + { + private Logger logger = LoggerFactory.getLogger(SPModel.class); + double minProductionProbability = 0.000000000000001; + int maxSubstringLength1 = 4; + int maxSubstringLength2 = 4; + + /** + * How many origins do we expect there to be? This should probably be about 50 or less. + */ + public static final int numOrigins = 1; + + private HashMap languageModel = null; + //private Dictionary languageModelDual=null; + + private List> trainingExamples; + //private List> trainingExamplesDual; + + /** + * This is the production probabilities table. + */ + private SparseDoubleVector probs = null; + + public SparseDoubleVector getProbs(){ + return probs; + } + + /** + * This is the production probabilities table multiplied by the segmentfactor. + */ + private SparseDoubleVector multiprobs = null; + + private HashMap pruned = null; // used to be a SparseDoubleVector + //private SparseDoubleVector> probsDual = null; + //private SparseDoubleVector> prunedDual = null; + private HashMap> probMap = null; + //private Map probMapDual = null; + + /** + * The maximum number of candidate transliterations returned by Generate, as well as preserved in intermediate steps. + * Larger values will possibly yield better transliterations, but will be more computationally expensive. + * The default value is 100. + */ + private int maxCandidates=100; + + public int getMaxCandidates(){ + return this.maxCandidates; + } + + public void setMaxCandidates(int value){ + this.maxCandidates = value; + } + + /** + * The (0,1] segment factor determines the preference for longer segments by effectively penalizing the total number of segments used. + * The probability of a transliteration is multiplied by SegmentFactor^[#segments]. Lower values prefer longer segments. A value of 0.5, the default, is better for generation, + * since productions from longer segments are more likely to be exactly correct; conversely, a value of 1 is better for discrimination. + * Set SegmentFactor before training begins. Setting it after this time effectively changes the model while keeping the same parameters, which isn't a good idea. + */ + private double segmentFactor = 0.5; + + public double getSegmentFactor(){ + return segmentFactor; + } + + public void setSegmentFactor(double value){ + this.segmentFactor = value; + } + + /** + * The size of the ngrams used by the language model (default: 4). Irrelevant if no language model is created with SetLanguageModel. + */ + private int ngramSize = 4; + + public int getNgramSize(){ + return this.ngramSize; + } + + public void setNgramSize(int value){ + this.ngramSize = value; + } + + /** + * This writes the model out to file. + * @param writer + * @param table + * @throws IOException + */ + private void WriteToWriter(DataOutputStream writer, SparseDoubleVector> table) throws IOException { + if (table == null) { + writer.write(-1); + } + else + { + writer.write(table.size()); + //for (KeyValuePair, Double> entry : table) + for(Pair key : table.keySet()) + { + Double value = table.get(key); + writer.writeChars(key.getFirst()); + writer.writeChars(key.getSecond()); + writer.writeDouble(value); + } + } + } + + private SparseDoubleVector> ReadTableFromReader(DataInputStream reader) throws IOException { + int count = reader.readInt(); + if (count == -1) return null; + SparseDoubleVector> table = new SparseDoubleVector<>(count); + for (int i = 0; i < count; i++) + { + // FIXME: this is readChar, probably should be readString?? 
+ Pair keyPair = new Pair<>(reader.readChar() + "", reader.readChar() + ""); + table.put(keyPair, reader.readDouble()); + } + + return table; + } + + /** + * This writes the production probabilities out to file in human-readable format. + * @param fname the name of the output file + * @param threshold only write probs above this threshold + * @throws IOException + */ + public void WriteProbs(String fname, double threshold) throws IOException { + ArrayList outlines = new ArrayList<>(); + + List keys = new ArrayList<>(probs.keySet()); + Collections.sort(keys, new Comparator() { + @Override + public int compare(Production o1, Production o2) { + return o1.getFirst().compareTo(o2.getFirst()); + } + }); + for(Production t : keys){ + if(probs.get(t) > threshold) { + String tstr = t.getFirst() + "\t" + t.getSecond(); + outlines.add(tstr + "\t" + probs.get(t)); + } + } + LineIO.write(fname, outlines); + } + + /** + * This just calls WriteProbs(fname, threshold) with threshold of 0. + * @param fname the name of the output file. + */ + public void WriteProbs(String fname) throws IOException { + WriteProbs(fname, 0.0); + } + + /** + * This is basically the reverse of the WriteProbs function. + */ + public void ReadProbs(String fname) throws FileNotFoundException { + List lines = LineIO.read(fname); + + probs = new SparseDoubleVector<>(); + + for(String line : lines){ + String[] sline = line.trim().split("\t"); + probs.put(new Production(sline[0], sline[1]), Double.parseDouble(sline[2])); + } + } + + + /** + * This reads a model from file. + */ + public SPModel(String fname) throws IOException { + if(probs != null){ + probs.clear(); + } + ReadProbs(fname); + } + + /** + * Creates a new model for generating transliterations (creating new words in the target language from a word in the source language). + * Remember to Train() the model before calling Generate(). + * @param examples The training examples to learn from. + */ + public SPModel(Collection examples) + { + trainingExamples = new ArrayList<>(examples.size()); + for(Example example : examples) { + // Default weight of examples is 1 + trainingExamples.add(example.Triple()); + } + } + + /** + * + * Creates and sets a simple (ngram) language model to use when generating transliterations. + * The transliterations will then have probability == P(T|S)*P(T), where P(T) is the language model. + * In principle this allows you to leverage the presumably vast number of words available in the target language, + * although practical results may vary. + * @param targetLanguageExamples Example words from the target language you're transliterating into. + */ + public void SetLanguageModel(List targetLanguageExamples) + { + logger.info("Setting language model with " + targetLanguageExamples.size() + " words."); + languageModel = Program.GetNgramCounts(targetLanguageExamples, maxSubstringLength2); + } + + /** + * Creates the language model directly. This was intended for use with SRILM. + * @param lm + */ + public void SetLanguageModel(HashMap lm){ + logger.info("Setting language model with hashmap directly."); + languageModel = lm; + } + + + /** + * This initializes with rom=false and testing = empty list. + */ + public void Train(int emIterations){ + Train(emIterations, false, new ArrayList()); + } + + /** + * Trains the model for the specified number of iterations. + * @param emIterations The number of iterations to train for. 
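+ * @param rom if true, seed the initial production table from a (Pinyin) romanization; see Program.InitializeWithRomanization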
+ * @param testing + */ + public void Train(int emIterations, boolean rom, List testing) + { + List> trainingTriples = trainingExamples; + + pruned = null; + multiprobs = null; + + // this is the initialization. + if (probs == null) + { + + // FIXME: out variables... is exampleCounts actually used anywhere??? + //List, Double>>> exampleCounts = new ArrayList<>(); + probs = new SparseDoubleVector<>(Program.MakeRawAlignmentTable(maxSubstringLength1, maxSubstringLength2, trainingTriples, null, Program.WeightingMode.None, WikiTransliteration.NormalizationMode.None, false)); + + boolean getExampleCounts = false; + // gets counts of productions, not normalized. + probs = new SparseDoubleVector<>(Program.MakeRawAlignmentTable(maxSubstringLength1, maxSubstringLength2, + trainingTriples, null, Program.WeightingMode.None, WikiTransliteration.NormalizationMode.None, getExampleCounts)); + + // this just normalizes by the source string. + probs = new SparseDoubleVector<>(Program.PSecondGivenFirst(probs)); + + // FIXME: uniform origin initialization? + probs = Program.SplitIntoOrigins(probs, this.numOrigins); + + if(rom) { + probs = Program.InitializeWithRomanization(probs, trainingTriples, testing); + } + + } + + for (int i = 0; i < emIterations; i++) + { + logger.info("Training, iteration=" + (i+1)); + boolean getExampleCounts = true; + // Difference is Weighting mode. + probs = new SparseDoubleVector<>(Program.MakeRawAlignmentTable(maxSubstringLength1, maxSubstringLength2, + trainingTriples, segmentFactor != 1 ? probs.multiply(segmentFactor) : probs, Program.WeightingMode.CountWeighted, WikiTransliteration.NormalizationMode.None, getExampleCounts)); + + // this just normalizes by the source string. + probs = new SparseDoubleVector<>(Program.PSecondGivenFirst(probs)); + } + } + + /** + * Calculates the probability P(T|S), that is, the probability that transliteratedWord is a transliteration of sourceWord. + * @param sourceWord The word is the source language + * @param transliteratedWord The purported transliteration of the source word, in the target language + * @return P(T|S) + */ + public double Probability(String sourceWord, String transliteratedWord) + { + if (multiprobs==null) { + multiprobs = probs.multiply(segmentFactor); + } + + HashMap memoizationTable = new HashMap<>(); + int orig = -1; + + if(sourceWord.length() > Program.segSums.length){ + System.err.println("Sourceword is too long (length " + sourceWord.length() + "). Setting prob=0"); + return 0; + }else if(transliteratedWord.length() > Program.segSums.length){ + System.err.println("TransliterateWord is too long (length " + transliteratedWord.length() + "). Setting prob=0"); + return 0; + } + + double score = WikiTransliteration.GetSummedAlignmentProbability(sourceWord, transliteratedWord, maxSubstringLength1, maxSubstringLength2, multiprobs, memoizationTable, minProductionProbability, orig) + / Program.segSums[sourceWord.length() - 1][transliteratedWord.length() - 1]; + + return score; + } + + /** + * Generates a TopList of the most likely transliterations of the given word. + * The TopList is like a SortedList sorted most-probable to least-probable with probabilities (doubles) as keys, + * except that the keys may not be unique (multiple transliterations can be equiprobable). + * The most likely transliteration is at index 0. + * The number of transliterations returned in this manner will not exceed the maxCandidates property. + * You must train the model before generating transliterations. 
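+     * <p>
+     * A minimal end-to-end sketch (hypothetical data; the exact shape of the Example
+     * constructor is assumed here, not confirmed by this file):
+     * <pre>{@code
+     * List<Example> pairs = new ArrayList<>();
+     * pairs.add(new Example("john", "yon"));       // assumed (source, target) constructor
+     * SPModel model = new SPModel(pairs);
+     * model.Train(10);                             // 10 EM iterations
+     * TopList<Double, String> cands = model.Generate("john");   // may throw Exception
+     * String best = cands.getFirst().getSecond();  // most probable transliteration
+     * }</pre>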
+ * @param sourceWord The word to transliterate. + * @return A TopList containing the most likely transliterations of the word. + */ + public TopList Generate(String sourceWord) throws Exception { + if (probs == null) throw new NullPointerException("Must train at least one iteration before generating transliterations"); + + if (pruned==null) + { + multiprobs = probs.multiply(segmentFactor); + pruned = Program.PruneProbs(maxCandidates, multiprobs); + probMap = WikiTransliteration.GetProbMap(pruned); + } + + TopList result = WikiTransliteration.Predict2(maxCandidates, sourceWord, maxSubstringLength2, probMap, pruned, new HashMap>(), maxCandidates); + + + if (languageModel != null) + { + TopList fPredictions = new TopList<>(maxCandidates); + for (Pair prediction : result) { + Double prob = Math.pow(WikiTransliteration.GetLanguageProbability(prediction.getSecond(), languageModel, ngramSize), 1); + double reranked = Math.log(prediction.getFirst()) + Math.log(prob) / prediction.getSecond().length(); + + fPredictions.add(Math.exp(reranked), prediction.getSecond()); + } + result = fPredictions; + } + + return result; + } + } diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/TransliterationModel.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/TransliterationModel.java new file mode 100644 index 000000000..ddf7b87a5 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/TransliterationModel.java @@ -0,0 +1,19 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import edu.illinois.cs.cogcomp.core.datastructures.Triple; + +import java.util.List; + +abstract class TransliterationModel { + public abstract double GetProbability(String word1, String word2); + + public abstract TransliterationModel LearnModel(List> examples); +} + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WikiAlias.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WikiAlias.java new file mode 100644 index 000000000..c400b1d97 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WikiAlias.java @@ -0,0 +1,25 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import java.io.Serializable; + +/** + * Created by stephen on 9/24/15. 
+ */ +public class WikiAlias implements Serializable { + public String alias; + public WikiTransliteration.AliasType type; + public int count; + + public WikiAlias(String alias, WikiTransliteration.AliasType type, int count) { + this.alias = alias; + this.type = type; + this.count = count; + } +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WikiTransliteration.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WikiTransliteration.java new file mode 100644 index 000000000..33fdc3c82 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WikiTransliteration.java @@ -0,0 +1,930 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + + +import edu.illinois.cs.cogcomp.core.datastructures.Pair; +import edu.illinois.cs.cogcomp.core.datastructures.Triple; +import edu.illinois.cs.cogcomp.utils.Dictionaries; +import edu.illinois.cs.cogcomp.utils.InternDictionary; +import edu.illinois.cs.cogcomp.utils.SparseDoubleVector; +import edu.illinois.cs.cogcomp.utils.TopList; +import org.apache.commons.lang3.StringUtils; + +import java.util.*; + +class WikiTransliteration { + + public class ContextModel { + public SparseDoubleVector, String>> productionProbs; + public SparseDoubleVector> segProbs; + public int segContextSize; + public int productionContextSize; + public int maxSubstringLength; + } + + public enum NormalizationMode { + None, + AllProductions, + BySourceSubstring, + BySourceSubstringMax, + BySourceAndTargetSubstring, + BySourceOverlap, + ByTargetSubstring + } + + // FIXME: can remove this? + public enum AliasType { + Unknown, + Link, + Redirect, + Title, + Disambig, + Interlanguage + } + + private static HashMap languageCodeTable; + + /** + * For querying wikipedia? 
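+     * These appear to be the ISO 639-1 two-letter codes, presumably for addressing the
+     * corresponding Wikipedia language editions.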
+ */ + public static final String[] languageCodes = new String[]{ + "aa", "ab", "ae", "af", "ak", "am", "an", "ar", "as", "av", "ay", "az", "ba", "be", "bg", "bh", "bi", "bm", "bn", "bo", "br", "bs", "ca", "ce", "ch", "co", "cr", "cs", "cu", "cv", "cy", "da", "de", "dv", "dz", "ee", "el", "en", "eo", "es", "et", "eu", "fa", "ff", "fi", "fj", "fo", "fr", "fy", "ga", "gd", "gl", "gn", "gu", "gv", "ha", "he", "hi", "ho", "hr", "ht", "hu", "hy", "hz", "ia", "id", "ie", "ig", "ii", "ik", "io", "is", "it", "iu", "ja", "jv", "ka", "kg", "ki", "kj", "kk", "kl", "km", "kn", "ko", "kr", "ks", "ku", "kv", "kw", "ky", "la", "lb", "lg", "li", "ln", "lo", "lt", "lu", "lv", "mg", "mh", "mi", "mk", "ml", "mn", "mr", "ms", "mt", "my", "na", "nb", "nd", "ne", "ng", "nl", "nn", "no", "nr", "nv", "ny", "oc", "oj", "om", "or", "os", "pa", "pi", "pl", "ps", "pt", "qu", "rm", "rn", "ro", "ru", "rw", "sa", "sc", "sd", "se", "sg", "sh", "si", "sk", "sl", "sm", "sn", "so", "sq", "sr", "ss", "st", "su", "sv", "sw", "ta", "te", "tg", "th", "ti", "tk", "tl", "tn", "to", "tr", "ts", "tt", "tw", "ty", "ug", "uk", "ur", "uz", "ve", "vi", "vo", "wa", "wo", "xh", "yi", "yo", "za", "zh", "zu" + }; + + public static Pair> GetAlignmentProbabilityDebug(String word1, String word2, int maxSubstringLength, HashMap probs, double minProb) { + Pair> result = GetAlignmentProbabilityDebug(word1, word2, maxSubstringLength, probs, minProb, new HashMap>>()); + return result; + } + + public static Pair> GetAlignmentProbabilityDebug(String word1, String word2, int maxSubstringLength, HashMap probs) { + return GetAlignmentProbabilityDebug(word1, word2, maxSubstringLength, probs, 0, new HashMap>>()); + } + + /** + * This used to have productions as an output variable. I (SWM) added it as the second element of return pair. 
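+     * <p>
+     * For example (hypothetical numbers), aligning "john" with "yon" might return 0.4 paired
+     * with the production sequence [("j","y"), ("oh","o"), ("n","n")]: the single best-scoring
+     * segmentation, not a sum over all alignments (contrast GetSummedAlignmentProbability).
+     * <pre>{@code
+     * // probs is a HashMap<Production, Double> of production probabilities
+     * Pair<Double, List<Production>> best = GetAlignmentProbabilityDebug("john", "yon", 4, probs);
+     * double p = best.getFirst();                 // probability of the best alignment
+     * List<Production> path = best.getSecond();   // productions along that alignment
+     * }</pre>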
+ * @param word1 + * @param word2 + * @param maxSubstringLength + * @param probs + * @param floorProb + * @param memoizationTable + * @return + */ + public static Pair> GetAlignmentProbabilityDebug(String word1, String word2, int maxSubstringLength, HashMap probs, double floorProb, HashMap>> memoizationTable) { + List productions = new ArrayList<>(); + Production bestPair = new Production(null, null); + + if (word1.length() == 0 && word2.length() == 0) return new Pair<>(1.0, productions); + if (word1.length() * maxSubstringLength < word2.length()) return new Pair<>(0.0, productions); //no alignment possible + if (word2.length() * maxSubstringLength < word1.length()) return new Pair<>(0.0, productions); + + Pair> cached; + if (memoizationTable.containsKey(new Production(word1, word2))) { + cached = memoizationTable.get(new Production(word1, word2)); + productions = cached.getSecond(); + return new Pair<>(cached.getFirst(), productions); + } + + double maxProb = 0; + + int maxSubstringLength1 = Math.min(word1.length(), maxSubstringLength); + int maxSubstringLength2 = Math.min(word2.length(), maxSubstringLength); + + for (int i = 1; i <= maxSubstringLength1; i++) { + String substring1 = word1.substring(0, i); + for (int j = 0; j <= maxSubstringLength2; j++) { + double localProb; + if (probs.containsKey(new Production(substring1, word2.substring(0, j)))) { + localProb = probs.get(new Production(substring1, word2.substring(0, j))); + //double localProb = ((double)count) / totals[substring1]; + if (localProb < maxProb || localProb < floorProb) + continue; //this is a really bad transition--discard + + List outProductions; + Pair> ret = GetAlignmentProbabilityDebug(word1.substring(i), word2.substring(j), maxSubstringLength, probs, maxProb / localProb, memoizationTable); + outProductions = ret.getSecond(); + + localProb *= ret.getFirst(); + if (localProb > maxProb) { + productions = outProductions; + maxProb = localProb; + bestPair = new Production(substring1, word2.substring(0, j)); + } + } + } + } + + productions = new ArrayList<>(productions); //clone it before modifying + productions.add(0, bestPair); + + memoizationTable.put(new Production(word1, word2), new Pair<>(maxProb, productions)); + + return new Pair<>(maxProb, productions); + } + + /** + * This is the same as probs, only in a more convenient format. Does not include weights on productions. + * @param probs production probabilities + * @return hashmap mapping from Production[0] => Production[1] + */ + public static HashMap> GetProbMap(HashMap probs) { + HashMap> result = new HashMap<>(); + for (Production pair : probs.keySet()) { + if(!result.containsKey(pair.getFirst())){ + result.put(pair.getFirst(), new HashSet()); + } + HashSet set = result.get(pair.getFirst()); + set.add(pair.getSecond()); + + result.put(pair.getFirst(), set); + } + + return result; + } + + + /** + * Given a word, ngramProbs, and an ngramSize, this predicts the probability of the word with respect to this language model. 
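+     * <p>
+     * The word is padded on both sides with '_' and scored left to right; when an ngram is
+     * unseen, the loop backs off to shorter ngrams, and returns probability 0 once the back-off
+     * reaches n == 1. For instance, with ngramSize = 3 the word "jon" is padded to "__jon__"
+     * and the scored ngrams are "__j", "_jo", "jon", "on_", "n__".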
+ * + * Compare this against: + * + * @param word + * @param ngramProbs + * @param ngramSize + * @return + */ + public static double GetLanguageProbability(String word, HashMap ngramProbs, int ngramSize) { + double probability = 1; + String paddedExample = StringUtils.repeat('_', ngramSize - 1) + word + StringUtils.repeat('_', ngramSize - 1); + for (int i = ngramSize - 1; i < paddedExample.length(); i++) { + int n = ngramSize; + //while (!ngramProbs.TryGetValue(paddedExample.substring(i - n + 1, n), out localProb)) { + + // This is a backoff procedure. + String ss = paddedExample.substring(i-n+1, i+1); + Double localProb = ngramProbs.get(ss); + while(localProb == null){ + n--; + if (n == 1) + return 0; // final backoff probability. Be careful with this... can result in names with 0 probability if the LM isn't large enough. + ss = paddedExample.substring(i-n+1, i+1); + localProb = ngramProbs.get(ss); + } + probability *= localProb; + } + + return probability; + } + + + /** + * This is used in the generation process. + * @param topK number of candidates to return + * @param word1 + * @param maxSubstringLength + * @param probMap a hashmap for productions, same as probs, but with no weights + * @param probs + * @param memoizationTable + * @param pruneToSize + * @return + */ + public static TopList Predict2(int topK, String word1, int maxSubstringLength, Map> probMap, HashMap probs, HashMap> memoizationTable, int pruneToSize) { + TopList result = new TopList<>(topK); + // calls a helper function + HashMap rProbs = Predict2(word1, maxSubstringLength, probMap, probs, memoizationTable, pruneToSize); + double probSum = 0; + + // gathers the total probability for normalization. + for (double prob : rProbs.values()) + probSum += prob; + + // this normalizes each value by the total prob. + for (String key : rProbs.keySet()) { + Double value = rProbs.get(key); + result.add(new Pair<>(value / probSum, key)); + } + + return result; + } + + /** + * Helper function. + * @param word1 + * @param maxSubstringLength + * @param probMap + * @param probs + * @param memoizationTable + * @param pruneToSize + * @return + */ + public static HashMap Predict2(String word1, int maxSubstringLength, Map> probMap, HashMap probs, HashMap> memoizationTable, int pruneToSize) { + HashMap result; + if (word1.length() == 0) { + result = new HashMap<>(1); + result.put("", 1.0); + return result; + } + + if (memoizationTable.containsKey(word1)) { + return memoizationTable.get(word1); + } + + result = new HashMap<>(); + + int maxSubstringLength1 = Math.min(word1.length(), maxSubstringLength); + + for (int i = 1; i <= maxSubstringLength1; i++) { + String substring1 = word1.substring(0, i); + + if (probMap.containsKey(substring1)) { + + // recursion right here. + HashMap appends = Predict2(word1.substring(i), maxSubstringLength, probMap, probs, memoizationTable, pruneToSize); + + //int segmentations = Segmentations( word1.Length - i ); + + for (String tgt : probMap.get(substring1)) { + Production alignment = new Production(substring1, tgt); + + double alignmentProb = probs.get(alignment); + + for (String key : appends.keySet()) { + Double value = appends.get(key); + String word = alignment.getSecond() + key; + //double combinedProb = (pair.Value/segmentations) * alignmentProb; + double combinedProb = (value) * alignmentProb; + + // I hope this is an accurate translation... 
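+                            // IncrementOrSet(map, key, inc, init) adds inc to the existing value, or
+                            // stores init when the key is absent; both are combinedProb here, so this
+                            // sums the probability mass of every segmentation yielding the same word.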
+ Dictionaries.IncrementOrSet(result, word, combinedProb, combinedProb); + } + } + + } + } + + if (result.size() > pruneToSize) { + Double[] valuesArray = result.values().toArray(new Double[result.values().size()]); + String[] data = result.keySet().toArray(new String[result.size()]); + + //Array.Sort (valuesArray, data); + + TreeMap sorted = new TreeMap<>(); + for(int i = 0 ; i < valuesArray.length; i++){ + sorted.put(valuesArray[i], data[i]); + } + + // FIXME: is this sorted in the correct order??? + + //double sum = 0; + //for (int i = data.Length - pruneToSize; i < data.Length; i++) + // sum += valuesArray[i]; + + result = new HashMap<>(pruneToSize); +// for (int i = data.length - pruneToSize; i < data.length; i++) +// result.put(data[i], valuesArray[i]); + + int i = 0; + for(Double d : sorted.descendingKeySet()){ + result.put(sorted.get(d), d); + if (i++ > pruneToSize){ + break; + } + } + } + + memoizationTable.put(word1, result); + return result; + } + + /** + * + * This makes sure to pad front and back of the string, if pad is True. + * + * @param n size of ngram + * @param examples list of examples + * @param pad whether or not we should use padding. + * @return + */ + public static HashMap GetNgramCounts(int n, Iterable examples, boolean pad) { + HashMap result = new HashMap<>(); + for (String example : examples) { + String padstring = StringUtils.repeat("_", n-1); + String paddedExample = (pad ? padstring + example + padstring : example); + + for (int i = 0; i <= paddedExample.length() - n; i++) { + //System.out.println(i + ": " + n); + Dictionaries.IncrementOrSet(result, paddedExample.substring(i, i+n), 1, 1); + } + } + + return result; + } + + public static HashMap GetFixedSizeNgramProbs(int n, Iterable examples) { + HashMap ngramCounts = GetNgramCounts(n, examples, true); + HashMap ngramTotals = new HashMap<>(); + for (String key : ngramCounts.keySet()) { + int v = ngramCounts.get(key); + Dictionaries.IncrementOrSet(ngramTotals, key.substring(0, n - 1), v, v); + } + + HashMap result = new HashMap<>(ngramCounts.size()); + for (String key : ngramCounts.keySet()) { + int v = ngramCounts.get(key); + result.put(key, ((double) v) / ngramTotals.get(key.substring(0, n - 1))); + } + + return result; + } + + public static HashMap GetNgramProbs(int minN, int maxN, Iterable examples) { + HashMap result = new HashMap<>(); + + for (int i = minN; i <= maxN; i++) { + HashMap map = GetFixedSizeNgramProbs(i, examples); + for (String key : map.keySet()) { + double value = map.get(key); + result.put(key, value); + } + } + + return result; + } + + /** + * This function loops over all sizes of ngrams, from minN to maxN, and creates + * an ngram model, and also normalizes it. + * + * @param minN minimum size ngram + * @param maxN maximum size ngram + * @param examples list of examples + * @param padding whether or not this should be padded + * @return a hashmap of ngrams. + */ + public static HashMap GetNgramCounts(int minN, int maxN, Iterable examples, boolean padding) { + HashMap result = new HashMap<>(); + for (int i = minN; i <= maxN; i++) { + HashMap counts = GetNgramCounts(i, examples, padding); + int total = 0; + for (int v : counts.values()) { + total += v; + } + + for (String key : counts.keySet()) { + int value = counts.get(key); + result.put(key, ((double) value) / total); + } + } + + return result; + } + + + + /** + * Given a map of productions and corresponding counts, get the counts of the source word in each + * production. 
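+     * For example, counts of {("jo","yo") -> 2.0, ("jo","jo") -> 3.0} yield {"jo" -> 5.0}.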
+ * @param counts production counts + * @return a map from source strings to counts. + */ + public static HashMap GetAlignmentTotals1(HashMap counts) { + // the string in this map is the source string. + HashMap result = new HashMap<>(); + for (Production key : counts.keySet()) { + Double value = counts.get(key); + + String source = key.getFirst(); + + // Increment or set + if(result.containsKey(source)){ + result.put(source, result.get(source) + value); + }else{ + result.put(source, value); + } + } + + return result; + } + + /** + * This finds all possible alignments between word1 and word2. + * @param word1 + * @param word2 + * @param maxSubstringLength1 + * @param maxSubstringLength2 + * @param internTable + * @param normalization + * @return + */ + public static HashMap FindAlignments(String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, InternDictionary internTable, NormalizationMode normalization) { + HashMap alignments = new HashMap<>(); + + // this populates the alignments hashmap. + // FIXME: why not assign to alignments here? + // FIXME: why is it boolean? Is the value ever false? What does it mean? + HashSet memoizationtable = new HashSet<>(); + FindAlignments(word1, word2, maxSubstringLength1, maxSubstringLength2, alignments, memoizationtable); + + // FIXME: probably don't need this? What about interning?? + HashMap result = new HashMap<>(alignments.size()); + for (Production key : alignments.keySet()) { + result.put(new Production(internTable.Intern(key.getFirst()), internTable.Intern(key.getSecond())), 1.0); + } + + return Normalize(word1, word2, result, internTable, normalization); + } + + /** + * This does no normalization, but interns the string in each production. + * @param counts + * @param internTable + * @return + */ + public static HashMap InternProductions(HashMap counts, InternDictionary internTable) { + HashMap result = new HashMap<>(counts.size()); + + for (Production key : counts.keySet()) { + Double value = counts.get(key); + result.put(new Production(internTable.Intern(key.getFirst()), internTable.Intern(key.getSecond()), key.getOrigin()), value); + } + + return result; + } + + /** + * This normalizes the raw counts by the counts of the source strings in each production. + * Example: + * Raw counts={ Prod("John", "Yon")=>4, Prod("John","Jon")=>1 } + * Normalized={ Prod("John", "Yon")=>4/5, Prod("John","Jon")=>1/5} + * + * If source strings only ever show up for one target string, then this does nothing. 
+ * + * @param counts raw counts + * @param internTable + * @return + */ + public static HashMap NormalizeBySourceSubstring(HashMap counts, InternDictionary internTable) { + // gets counts by source strings + HashMap totals = GetAlignmentTotals1(counts); + + HashMap result = new HashMap<>(counts.size()); + + for (Production key : counts.keySet()) { + Double value = counts.get(key); + result.put(new Production(internTable.Intern(key.getFirst()), internTable.Intern(key.getSecond())), value / totals.get(key.getFirst())); + } + + return result; + } + + public static HashMap GetSourceSubstringMax(HashMap, Double> counts) { + HashMap result = new HashMap<>(counts.size()); + for (Pair key : counts.keySet()) { + Double value = counts.get(key); + if (result.containsKey(key.getFirst())) + result.put(key.getFirst(), Math.max(value, result.get(key.getFirst()))); + else + result.put(key.getFirst(), value); + } + + return result; + } + + public static HashMap Normalize(String sourceWord, String targetWord, HashMap counts, InternDictionary internTable, NormalizationMode normalization) { + if (normalization == NormalizationMode.BySourceSubstring) + return NormalizeBySourceSubstring(counts, internTable); +// else if (normalization == NormalizationMode.AllProductions) +// return NormalizeAllProductions(counts, internTable); +// else if (normalization == NormalizationMode.BySourceSubstringMax) +// return NormalizeBySourceSubstringMax(counts, internTable); +// else if (normalization == NormalizationMode.BySourceAndTargetSubstring) +// return NormalizeBySourceAndTargetSubstring(counts, internTable); +// else if (normalization == NormalizationMode.BySourceOverlap) +// return NormalizeBySourceOverlap(sourceWord, counts, internTable); +// else if (normalization == NormalizationMode.ByTargetSubstring) +// return NormalizeByTargetSubstring(counts, internTable); + else + return InternProductions(counts, internTable); + } + + public static HashMap FindWeightedAlignments(String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap probs, InternDictionary internTable, NormalizationMode normalization) { + HashMap weights = new HashMap<>(); + FindWeightedAlignments(1, new ArrayList(), word1, word2, maxSubstringLength1, maxSubstringLength2, probs, weights, new HashMap>()); + + //CheckDictionary(weights); + + HashMap weights2 = new HashMap<>(weights.size()); + for (Production wkey : weights.keySet()) { + weights2.put(wkey, weights.get(wkey) == 0 ? 0 : weights.get(wkey) / probs.get(wkey)); + } + //weights2[wPair.Key] = weights[wPair.Key] == 0 ? 0 : Math.Pow(weights[wPair.Key], 1d / word1.Length); + weights = weights2; + + return Normalize(word1, word2, weights, internTable, normalization); + } + + public static HashMap FindWeightedAlignmentsAverage(String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap probs, InternDictionary internTable, Boolean weightByOthers, NormalizationMode normalization) { + HashMap weights = new HashMap<>(); + HashMap weightCounts = new HashMap<>(); + //FindWeightedAlignmentsAverage(1, new List>(), word1, word2, maxSubstringLength1, maxSubstringLength2, probs, weights, weightCounts, new HashMap, Pair>(), weightByOthers); + FindWeightedAlignmentsAverage(1, new ArrayList(), word1, word2, maxSubstringLength1, maxSubstringLength2, probs, weights, weightCounts, weightByOthers); + + //CheckDictionary(weights); + + HashMap weights2 = new HashMap<>(weights.size()); + for (Production wkey : weights.keySet()) + weights2.put(wkey, weights.get(wkey) == 0 ? 
0 : weights.get(wkey) / weightCounts.get(wkey)); + weights = weights2; + + return Normalize(word1, word2, weights, internTable, normalization); + } + + public static double FindWeightedAlignments(double probability, List productions, String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap probs, HashMap weights, HashMap> memoizationTable) { + if (word1.length() == 0 && word2.length() == 0) //record probabilities + { + for (Production production : productions) { + if(weights.containsKey(production) && weights.get(production) > probability){ + continue; + }else{ + weights.put(production, probability); + } + } + return 1; + } + + //Check memoization table to see if we can return early + Pair probPair; + + if(memoizationTable.containsKey(new Production(word1, word2))){ + probPair = memoizationTable.get(new Production(word1, word2)); + if (probPair.getFirst() >= probability) //we ran against these words with a higher probability before; + { + probability *= probPair.getSecond(); //get entire production sequence probability + + for (Production production : productions) { + if(weights.containsKey(production) && weights.get(production) > probability){ + continue; + }else{ + weights.put(production, probability); + } + } + + return probPair.getSecond(); + } + } + + int maxSubstringLength1f = Math.min(word1.length(), maxSubstringLength1); + int maxSubstringLength2f = Math.min(word2.length(), maxSubstringLength2); + + double bestProb = 0; + + for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... + { + String substring1 = word1.substring(0, i); + + for (int j = 1; j <= maxSubstringLength2f; j++) //for possible substring in the second + { + if ((word1.length() - i) * maxSubstringLength2 >= word2.length() - j && (word2.length() - j) * maxSubstringLength1 >= word1.length() - i) //if we get rid of these characters, can we still cover the remainder of word2? + { + String substring2 = word2.substring(0, j); + Production production = new Production(substring1, substring2); + double prob = probs.get(production); + + productions.add(production); + double thisProb = prob * FindWeightedAlignments(probability * prob, productions, word1.substring(i), word2.substring(j), maxSubstringLength1, maxSubstringLength2, probs, weights, memoizationTable); + productions.remove(productions.size() - 1); + + if (thisProb > bestProb) bestProb = thisProb; + } + } + } + + memoizationTable.put(new Production(word1, word2), new Pair<>(probability, bestProb)); + return bestProb; + } + + public static double FindWeightedAlignmentsAverage(double probability, List productions, String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap probs, HashMap weights, HashMap weightCounts, Boolean weightByOthers) { + if (probability == 0) return 0; + + if (word1.length() == 0 && word2.length() == 0) //record probabilities + { + for (Production production : productions) { + double probValue = weightByOthers ? 
probability / probs.get(production) : probability; + //weight the contribution to the average by its probability (square it) + Dictionaries.IncrementOrSet(weights, production, probValue * probValue, probValue * probValue); + Dictionaries.IncrementOrSet(weightCounts, production, probValue, probValue); + } + return 1; + } + + int maxSubstringLength1f = Math.min(word1.length(), maxSubstringLength1); + int maxSubstringLength2f = Math.min(word2.length(), maxSubstringLength2); + + double bestProb = 0; + + for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... + { + String substring1 = word1.substring(0, i); + + for (int j = 1; j <= maxSubstringLength2f; j++) //for possible substring in the second + { + if ((word1.length() - i) * maxSubstringLength2 >= word2.length() - j && (word2.length() - j) * maxSubstringLength1 >= word1.length() - i) //if we get rid of these characters, can we still cover the remainder of word2? + { + String substring2 = word2.substring(0, j); + Production production = new Production(substring1, substring2); + double prob = probs.get(production); + + productions.add(production); + double thisProb = prob * FindWeightedAlignmentsAverage(probability * prob, productions, word1.substring(i), word2.substring(j), maxSubstringLength1, maxSubstringLength2, probs, weights, weightCounts, weightByOthers); + productions.remove(productions.size() - 1); + + if (thisProb > bestProb) bestProb = thisProb; + } + } + } + + //memoizationTable[new Pair(word1, word2)] = new Pair(probability, bestProb); + return bestProb; + } + + + /** + * Finds the single best alignment for the two words and uses that to increment the counts. + * WeighByProbability does not use the real, noramalized probability, but rather a proportional probability + * and is thus not "theoretically valid". + * @param word1 + * @param word2 + * @param maxSubstringLength + * @param probs + * @param internTable + * @param weighByProbability + * @return + */ + public static HashMap CountMaxAlignments(String word1, String word2, int maxSubstringLength, HashMap probs, InternDictionary internTable, Boolean weighByProbability) { + + Pair> result1 = GetAlignmentProbabilityDebug(word1, word2, maxSubstringLength, probs); + double prob = result1.getFirst(); + List productions = result1.getSecond(); + //CheckDictionary(weights); + + HashMap result = new HashMap<>(productions.size()); + + if (prob == 0) //no possible alignment for some reason + { + return result; //nothing learned //result.Add(new Pair(internTable.Intern(word1),internTable.Intern(word2), + } + + for (Production production : productions) { + Dictionaries.IncrementOrSet(result, new Production(internTable.Intern(production.getFirst()), internTable.Intern(production.getSecond())), weighByProbability ? prob : 1, weighByProbability ? prob : 1); + } + + + return result; + } + + /** + * What does this do? 
Largely calls CountWeightedAlignmentsHelper + * @param word1 + * @param word2 + * @param maxSubstringLength1 + * @param maxSubstringLength2 + * @param probs + * @param internTable + * @param normalization + * @param weightByContextOnly + * @return + */ + public static HashMap CountWeightedAlignments(String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap probs, InternDictionary internTable, NormalizationMode normalization, Boolean weightByContextOnly) { + //HashMap, double> weights = new HashMap, double>(); + //HashMap, double> weightCounts = new HashMap, double>(); + //FindWeightedAlignmentsAverage(1, new List>(), word1, word2, maxSubstringLength1, maxSubstringLength2, probs, weights, weightCounts, new HashMap, Pair>(), weightByOthers); + Pair, Double> Q = CountWeightedAlignmentsHelper(word1, word2, maxSubstringLength1, maxSubstringLength2, probs, new HashMap, Double>>()); + HashMap weights = Q.getFirst(); + double probSum = Q.getSecond(); //the sum of the probabilities of all possible alignments + + // this is where the 1/y normalization happens for this word pair. + HashMap weights_norm = new HashMap<>(weights.size()); + for (Production key : weights.keySet()) { + Double value = weights.get(key); + if (weightByContextOnly) { + double originalProb = probs.get(key); + weights_norm.put(key, value == 0 ? 0 : (value / originalProb) / (probSum - value + (value / originalProb))); + } else + weights_norm.put(key, value == 0 ? 0 : value / probSum); + } + + return Normalize(word1, word2, weights_norm, internTable, normalization); + } + + /** + * Gets counts for productions by (conceptually) summing over all the possible alignments + * and weighing each alignment (and its constituent productions) by the given probability table. + * probSum is important (and memoized for input word pairs)--it keeps track and returns the sum of the + * probabilities of all possible alignments for the word pair + * + * This is Algorithm 3 in the paper. + * + * @param word1 + * @param word2 + * @param maxSubstringLength1 + * @param maxSubstringLength2 + * @param probs + * @param memoizationTable + * @return a hashmap and double as a pair. The double is y, a normalization constant. The hashmap is a table of substring pairs + * and their unnormalized counts + */ + public static Pair, Double> CountWeightedAlignmentsHelper(String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap probs, HashMap, Double>> memoizationTable) { + double probSum; + + Pair, Double> memoization; + for(int orig = 0; orig < SPModel.numOrigins; orig++) { + if (memoizationTable.containsKey(new Production(word1, word2, orig))) { + memoization = memoizationTable.get(new Production(word1, word2, orig)); + probSum = memoization.getSecond(); //stored probSum + return new Pair<>(memoization.getFirst(), probSum); //table of probs + } + } + + HashMap result = new HashMap<>(); // this is C in Algorithm 3 in the paper + probSum = 0; // this is R in Algorithm 3 in the paper + + if (word1.length() == 0 && word2.length() == 0) //record probabilities + { + probSum = 1; //null -> null is always a perfect alignment + return new Pair<>(result,probSum); //end of the line + } + + int maxSubstringLength1f = Math.min(word1.length(), maxSubstringLength1); + int maxSubstringLength2f = Math.min(word2.length(), maxSubstringLength2); + + for(int orig = 0; orig < SPModel.numOrigins; orig++) { + for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... 
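+            // Recurrence (Algorithm 3): each production (substring1, substring2, orig) adds
+            // prob * remainderProbSum to its own count, all counts gathered from the remainder
+            // are scaled by prob, and probSum accumulates prob * remainderProbSum.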
+ { + String substring1 = word1.substring(0, i); + + for (int j = 1; j <= maxSubstringLength2f; j++) //for possible substring in the second + { + if ((word1.length() - i) * maxSubstringLength2 >= word2.length() - j && (word2.length() - j) * maxSubstringLength1 >= word1.length() - i) //if we get rid of these characters, can we still cover the remainder of word2? + { + String substring2 = word2.substring(0, j); + + Production production = new Production(substring1, substring2, orig); + double prob = probs.get(production); + + // recurse here. Result is Q in Algorithm 3 + Pair, Double> Q = CountWeightedAlignmentsHelper(word1.substring(i), word2.substring(j), maxSubstringLength1, maxSubstringLength2, probs, memoizationTable); + + HashMap remainderCounts = Q.getFirst(); + Double remainderProbSum = Q.getSecond(); + + Dictionaries.IncrementOrSet(result, production, prob * remainderProbSum, prob * remainderProbSum); + + //update our probSum + probSum += remainderProbSum * prob; + + //update all the productions that come later to take into account their preceding production's probability + for (Production key : remainderCounts.keySet()) { + Double value = remainderCounts.get(key); + Dictionaries.IncrementOrSet(result, key, prob * value, prob * value); + } + } + } + } + } + + for(int orig = 0; orig < SPModel.numOrigins; orig++) { + memoizationTable.put(new Production(word1, word2, orig), new Pair<>(result, probSum)); + } + return new Pair<>(result, probSum); + } + + public static String[] GetLeftFallbackContexts(String word, int position, int contextSize) { + String[] result = new String[contextSize + 1]; + for (int i = 0; i < result.length; i++) + result[i] = word.substring(position - i, position); + + return result; + } + + public static String[] GetRightFallbackContexts(String word, int position, int contextSize) { + String[] result = new String[contextSize + 1]; + for (int i = 0; i < result.length; i++) + result[i] = word.substring(position, position+i); + + return result; + } + + public static String GetLeftContext(String word, int position, int contextSize) { + return word.substring(position - contextSize, position); + } + + public static String GetRightContext(String word, int position, int contextSize) { + return word.substring(position, position+contextSize); + } + + /** + * Finds the probability of word1 transliterating to word2 over all possible alignments + * This is Algorithm 1 in the paper. + * @param word1 Source word + * @param word2 Transliterated word + * @param maxSubstringLength1 constant field from SPModel + * @param maxSubstringLength2 constant field from SPModel + * @param probs map from production to weight?? 
+ * @param memoizationTable + * @param minProductionProbability + * @return + */ + public static double GetSummedAlignmentProbability(String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap probs, HashMap memoizationTable, double minProductionProbability, int origin) { + + if(memoizationTable.containsKey(new Production(word1, word2, origin))){ + return memoizationTable.get(new Production(word1, word2, origin)); + } + + if (word1.length() == 0 && word2.length() == 0) //record probabilities + return 1; //null -> null is always a perfect alignment + + double probSum = 0; + + int maxSubstringLength1f = Math.min(word1.length(), maxSubstringLength1); + int maxSubstringLength2f = Math.min(word2.length(), maxSubstringLength2); + + double localMinProdProb = 1; + for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... + { + localMinProdProb *= minProductionProbability; + + String substring1 = word1.substring(0, i); + + for (int j = 1; j <= maxSubstringLength2f; j++) //for possible substring in the second + { + //if we get rid of these characters, can we still cover the remainder of word2? + if ((word1.length() - i) * maxSubstringLength2 >= word2.length() - j && (word2.length() - j) * maxSubstringLength1 >= word1.length() - i) + { + String substring2 = word2.substring(0, j); + Production production = new Production(substring1, substring2, origin); + + double prob = 0; + + if(!probs.containsKey(production)){ + if (localMinProdProb == 0){ + continue; + } + }else{ + prob = probs.get(production); + } + + prob = Math.max(prob, localMinProdProb); + + double remainderProbSum = GetSummedAlignmentProbability(word1.substring(i), word2.substring(j), maxSubstringLength1, maxSubstringLength2, probs, memoizationTable, minProductionProbability, origin); + + //update our probSum + probSum += remainderProbSum * prob; + } + } + } + + memoizationTable.put(new Production(word1, word2), probSum); + return probSum; + } + + /** + * This recursively finds all possible alignments between word1 and word2 and populates the alignments hashmap with them. + * + * @param word1 word or substring of a word + * @param word2 word or substring of a word + * @param maxSubstringLength1 + * @param maxSubstringLength2 + * @param alignments this is the result + * @param memoizationTable + */ + public static void FindAlignments(String word1, String word2, int maxSubstringLength1, int maxSubstringLength2, HashMap alignments, HashSet memoizationTable) { + if (memoizationTable.contains(new Production(word1, word2))) + return; //done + + int maxSubstringLength1f = Math.min(word1.length(), maxSubstringLength1); + int maxSubstringLength2f = Math.min(word2.length(), maxSubstringLength2); + + for (int i = 1; i <= maxSubstringLength1f; i++) //for each possible substring in the first word... + { + String substring1 = word1.substring(0, i); + + for (int j = 1; j <= maxSubstringLength2f; j++) //for possible substring in the second + { + //if we get rid of these characters, can we still cover the remainder of word2? 
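+                // Feasibility check: after consuming i characters of word1 and j of word2, one
+                // character can align with at most maxSubstringLength characters on the other
+                // side, so both remainders must still be able to cover each other.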
+ if ((word1.length() - i) * maxSubstringLength2 >= word2.length() - j && (word2.length() - j) * maxSubstringLength1 >= word1.length() - i) + { + alignments.put(new Production(substring1, word2.substring(0, j)), 1.0); + FindAlignments(word1.substring(i), word2.substring(j), maxSubstringLength1, maxSubstringLength2, alignments, memoizationTable); + } + } + } + + memoizationTable.add(new Production(word1, word2)); + } +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordAlignment.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordAlignment.java new file mode 100644 index 000000000..dbb1e074f --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordAlignment.java @@ -0,0 +1,67 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import java.io.Serializable; + +class WordAlignment implements Serializable { + /// + /// Each String had exactly one word. + /// + public int oneToOne; + + /// + /// There was an equal number of more than one words in each String. + /// + public int equalNumber; + + /// + /// There were more words in one String than the other. + /// + public int unequalNumber; + + public WordAlignment(int oneToOne, int equalNumber, int unequalNumber) { + this.oneToOne = oneToOne; + this.equalNumber = equalNumber; + this.unequalNumber = unequalNumber; + } + + @Override + public String toString() { + return oneToOne + ":" + equalNumber + ":" + unequalNumber; + } + + public WordAlignment(String wordAlignmentString) { + String[] values = wordAlignmentString.split(":"); + oneToOne = Integer.parseInt(values[0]); + equalNumber = Integer.parseInt(values[1]); + unequalNumber = Integer.parseInt(values[2]); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + WordAlignment that = (WordAlignment) o; + + if (oneToOne != that.oneToOne) return false; + if (equalNumber != that.equalNumber) return false; + return unequalNumber == that.unequalNumber; + + } + + @Override + public int hashCode() { + int result = oneToOne; + result = 31 * result + equalNumber; + result = 31 * result + unequalNumber; + return result; + } +} + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordCompression.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordCompression.java new file mode 100644 index 000000000..a6d383c6b --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordCompression.java @@ -0,0 +1,187 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. 
Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + +import edu.illinois.cs.cogcomp.utils.InternDictionary; +import edu.illinois.cs.cogcomp.utils.SparseDoubleVector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +class WordCompression { + + public static List GetNgrams(String example) { + HashMap ngramList = new HashMap<>(); + for (int n = 1; n <= example.length(); n++) + for (int i = 0; i <= example.length() - n; i++) { + ngramList.put(example.substring(i, n), true); + } + + return new ArrayList<>(ngramList.keySet()); + } + + private static int MinChunkCount(String example, HashSet chunks) { + //jon + int[] counts = new int[example.length() + 1]; + for (int i = example.length() - 1; i >= 0; i--) { + counts[i] = Integer.MAX_VALUE; + int maxLength = example.length() - i; + for (int j = 1; j <= maxLength; j++) { + if (chunks.contains(example.substring(i, j))) + counts[i] = Math.min(1 + counts[i + j], counts[i]); + } + } + + return counts[0]; + } + + private static String[] MinChunks(String example, HashSet chunks) { + //jon + int[] counts = new int[example.length() + 1]; + String[] bestString = new String[example.length()]; + for (int i = example.length() - 1; i >= 0; i--) { + counts[i] = Integer.MAX_VALUE; + int maxLength = example.length() - i; + for (int j = 1; j <= maxLength; j++) { + String ss = example.substring(i, j); + if (chunks.contains(ss)) { + int addCount = 1 + counts[i + j]; + if (addCount < counts[i]) { + counts[i] = addCount; + bestString[i] = ss; + } + } + } + } + + if (bestString[0] == null) return null; + + String[] result = new String[counts[0]]; + + int nextOffset = 0; + for (int i = 0; i < result.length; i++) { + result[i] = bestString[nextOffset]; + nextOffset += result[i].length(); + } + + return result; + } + +// public static void Compress(SparseDoubleVector examples) { +// InternDictionary internTable = new InternDictionary(); +// +// //build a map from substrings to words +// HashMap> substringMap = new HashMap>(); +// for (String example : examples.Keys) +// for (String ngram : GetNgrams(example)) { +// List wordList; +// if (!substringMap.TryGetValue(ngram, out wordList)) +// wordList = substringMap[ngram] = new List(); +// +// wordList.Add(example); +// } +// +// //initialize the chunk set +// HashSet chunks = new HashSet(); +// for (String ngram : substringMap.Keys) +// if (ngram.Length == 1) chunks.Add(ngram); +// +// HashMap chunksRequired = new HashMap(examples.Count); +// int totalSegments = 0; +// for (String example : examples.Keys) +// totalSegments += ((int) examples[example]) * (chunksRequired[example] = MinChunkCount(example, chunks)); +// +// int chunksLength = chunks.size(); +// +// double currentScore = totalSegments * Math.Log(chunks.size(), 2) + 8 * (chunksLength + chunks.size()); //initial score +// +// System.out.println("Initial score = " + currentScore + "; " + totalSegments + " segments; " + chunks.size() + " chunks."); +// +// int round = 0; +// +// while (true) { +// round++; +// +// double bestScore = currentScore; +// String bestMove = null; +// //Set chunksCopy = new Set(chunks); +// for (String ngram : examples.Keys) +// { +// if (chunks.contains(ngram)) { +// chunks.Remove(ngram); //try deleting it +// chunksLength -= ngram.Length; +// } else { +// chunks.Add(ngram); //try adding it +// chunksLength += ngram.Length; 
+// } +// +// int chunkChange = 0; +// Boolean impossible = false; +// for (String example : substringMap[ngram]) { +// int cc = MinChunkCount(example, chunks); +// if (cc == int.MaxValue) { +// impossible = true; +// break; +// } +// chunkChange += ((int) examples[example]) * (cc - chunksRequired[example]); +// } +// +// if (!impossible) { +// double newScore = (totalSegments + chunkChange) * Math.log(chunks.size(), 2) + 8 * (chunksLength + chunks.size()); +// if (newScore < bestScore) { +// bestScore = newScore; +// bestMove = ngram; +// } +// } +// +// if (chunks.contains(ngram)) { +// chunks.remove(ngram); //try deleting it +// chunksLength -= ngram.Length; +// } else { +// chunks.add(ngram); //try adding it +// chunksLength += ngram.Length; +// } +// +// +// } +// +// if (bestMove == null) { +// System.out.println("Finished. Local max found. Return to quit."); +// SaveSet(chunks, @ "C:\Data\WikiTransliteration\Segmentation\chunks.txt"); +// Console.ReadLine(); +// return; +// } else { +// System.out.println("Compressing (round " + round + "): Old score = " + currentScore + "; new score = " + bestScore); +// System.out.println("Segments per word: " + (((double) totalSegments) / examples.Count) + "; " + totalSegments + " segments; " + chunks.size() + " chunks (length = " + chunksLength + ")"); +// if (!chunks.contains(bestMove)) +// System.out.println("Adding " + bestMove); +// else System.out.println("Removing " + bestMove); +// System.out.println(); +// +// currentScore = bestScore; +// +// if (chunks.contains(bestMove)) { +// chunks.remove(bestMove); //try deleting it +// chunksLength -= bestMove.Length; +// } else { +// chunks.add(bestMove); //try adding it +// chunksLength += bestMove.Length; +// } +// +// for (String example : substringMap[bestMove]) { +// int cc = MinChunkCount(example, chunks); +// totalSegments += ((int) examples[example]) * (cc - chunksRequired[example]); +// chunksRequired[example] = cc; +// } +// } +// } +// } +} + diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordSegmentation.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordSegmentation.java new file mode 100644 index 000000000..c07223e0f --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/transliteration/WordSegmentation.java @@ -0,0 +1,17 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.transliteration; + + +import edu.illinois.cs.cogcomp.utils.SparseDoubleVector; +import org.apache.commons.lang.StringUtils; + +import java.util.HashMap; + +class WordSegmentation { +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/Dictionaries.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/Dictionaries.java new file mode 100644 index 000000000..adb5a73dc --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/Dictionaries.java @@ -0,0 +1,92 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. 
Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.utils; + +import edu.illinois.cs.cogcomp.transliteration.Production; + +import java.util.HashMap; + +/** + * Created by stephen on 10/8/15. + */ +public class Dictionaries { + + public static void IncrementOrSet(HashMap m, K p, int incrementValue, int setValue){ + if(m.containsKey(p)){ + m.put(p, m.get(p) + incrementValue); + }else{ + m.put(p, setValue); + } + } + + public static void IncrementOrSet(HashMap m, K p, double incrementValue, double setValue){ + if(m.containsKey(p)){ + m.put(p, m.get(p) + incrementValue); + }else{ + m.put(p, setValue); + } + } + + public static void AddTo(HashMap vector, HashMap values, int valuesCoefficient) { + for(K k : values.keySet()){ + int v = valuesCoefficient*values.get(k); + IncrementOrSet(vector, k, v, v); + } + } + + /** + * Given two dictionaries, this adds the first to the second, multiplying each value in the second dictionary by valuesCoefficient. + * @param vector + * @param values + * @param valuesCoefficient + */ + public static void AddTo(HashMap vector, HashMap values, double valuesCoefficient) { + for(K k : values.keySet()){ + double v = valuesCoefficient*values.get(k); + IncrementOrSet(vector, k, v, v); + } + } + + /** + * Treats the two dictionaries as sparse vectors and multiplies their values on matching keys. Things without a key are treated as having value 0. + * @param vector1 + * @param vector2 + * @return + */ + public static HashMap MultiplyInt(HashMap vector1, HashMap vector2) { + HashMap ret = new HashMap<>(); + + for(K p : vector1.keySet()){ + if(vector2.containsKey(p)){ + ret.put(p, vector1.get(p) * vector2.get(p)); + } + } + + return ret; + + } + + /** + * Treats the two dictionaries as sparse vectors and multiplies their values on matching keys. Things without a key are treated as having value 0. + * @param vector1 + * @param vector2 + * @return + */ + public static HashMap MultiplyDouble(HashMap vector1, HashMap vector2) { + HashMap ret = new HashMap<>(); + + for(K p : vector1.keySet()){ + if(vector2.containsKey(p)){ + ret.put(p, vector1.get(p) * vector2.get(p)); + } + } + + return ret; + + } +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/InternDictionary.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/InternDictionary.java new file mode 100644 index 000000000..cc8ef3297 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/InternDictionary.java @@ -0,0 +1,41 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.utils; + +import org.apache.commons.lang3.NotImplementedException; + +import java.util.HashMap; + +/** + * Created by stephen on 9/30/15. + */ +public class InternDictionary extends HashMap { + + public InternDictionary(){ + // don't think I need anything here... + } + + /** + * The reason this works is that containsKey uses the .equals method, which actually compares contents. This way, + * we can save memory because fewer strings are actually stored, but also avoid some nasty bugs. 
That is, we may + * have two strings with different locations in memory, but the same contents (dynamically created strings). If + * the code uses a == operator, it will return false. This way, all strings with the same contents will be the same + * location in memory. + * @param obj + * @return + */ + public T Intern(T obj){ + + if(!this.containsKey(obj)) { + this.put(obj, obj); + } + + return this.get(obj); + + } +} diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/SparseDoubleVector.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/SparseDoubleVector.java new file mode 100644 index 000000000..6df31c522 --- /dev/null +++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/SparseDoubleVector.java @@ -0,0 +1,202 @@ +/** + * This software is released under the University of Illinois/Research and Academic Use License. See + * the LICENSE file in the root folder for details. Copyright (c) 2016 + * + * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign + * http://cogcomp.cs.illinois.edu/ + */ +package edu.illinois.cs.cogcomp.utils; + +import org.apache.commons.lang3.NotImplementedException; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** + * SparseDoubleVector: basically a hashmap. + * Created by stephen on 9/24/15. + */ +public class SparseDoubleVector extends HashMap implements Serializable { + + public SparseDoubleVector(){ + super(); + } + + public SparseDoubleVector(Map dictionary) { + super(dictionary); + //this.putAll(dictionary); + } + + public SparseDoubleVector(int capacity){ + super(capacity); + } + + + /** + * SWM: created to replicate operator method + * @param value + * @return + */ + public SparseDoubleVector divide(double value){ + SparseDoubleVector ret = new SparseDoubleVector<>(); + + for(TKey t : this.keySet()){ + ret.put(t, this.get(t) / value); + } + + return ret; + } + + /** + * SWM: created to replicate operator method + * FIXME: NOT SURE IF THIS IS CORRECT. + * @param value + * @return + */ + public SparseDoubleVector divide(SparseDoubleVector value){ + SparseDoubleVector ret = new SparseDoubleVector<>(); + + for(TKey k : this.keySet()){ + double denom = 1.0; + if(value.containsKey(k)) { + denom = value.get(k); + } + ret.put(k, this.get(k) / denom); + } + + return ret; + } + + /** + * SWM: created to replicate operator method + * @param value + * @return + */ + public SparseDoubleVector multiply(double value){ + SparseDoubleVector ret = new SparseDoubleVector<>(); + + for(TKey t : this.keySet()){ + ret.put(t, this.get(t) * value); + } + + return ret; + } + + public SparseDoubleVector Abs(){ + throw new NotImplementedException("not yet implemented..."); + } + + public void put(SparseDoubleVector values){ + throw new NotImplementedException("not yet implemented..."); + } + + /** + * Add all elements of another vector to this vector, and multiply by a coefficient. + * @param coefficient + * @param values + */ + public void put(double coefficient, SparseDoubleVector values){ + for(TKey k : values.keySet()){ + this.put(k, coefficient * values.get(k)); + } + } + + + public SparseDoubleVector Ceiling(){ + throw new NotImplementedException("not yet implemented..."); + } + + public void Clear(){ + this.clear(); + } + + public boolean ContainsKey(TKey key){ + throw new NotImplementedException("not yet implemented..."); + } + + /** + * Takes the exponential of each element in the vector. 
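+     * For example, {a -> 0.0, b -> 1.0}.Exp() returns a new vector {a -> 1.0, b -> e};
+     * this vector itself is unchanged.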
+}
diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/SparseDoubleVector.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/SparseDoubleVector.java
new file mode 100644
index 000000000..6df31c522
--- /dev/null
+++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/SparseDoubleVector.java
@@ -0,0 +1,202 @@
+/**
+ * This software is released under the University of Illinois/Research and Academic Use License. See
+ * the LICENSE file in the root folder for details. Copyright (c) 2016
+ *
+ * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign
+ * http://cogcomp.cs.illinois.edu/
+ */
+package edu.illinois.cs.cogcomp.utils;
+
+import org.apache.commons.lang3.NotImplementedException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * SparseDoubleVector: a sparse vector backed by a HashMap from keys to double values.
+ * Created by stephen on 9/24/15.
+ */
+public class SparseDoubleVector<TKey> extends HashMap<TKey, Double> implements Serializable {
+
+    public SparseDoubleVector() {
+        super();
+    }
+
+    public SparseDoubleVector(Map<TKey, Double> dictionary) {
+        super(dictionary);
+    }
+
+    public SparseDoubleVector(int capacity) {
+        super(capacity);
+    }
+
+    /**
+     * SWM: created to replicate the C# / operator.
+     * @param value the scalar divisor
+     * @return a new vector with every element divided by value
+     */
+    public SparseDoubleVector<TKey> divide(double value) {
+        SparseDoubleVector<TKey> ret = new SparseDoubleVector<>();
+
+        for (TKey t : this.keySet()) {
+            ret.put(t, this.get(t) / value);
+        }
+
+        return ret;
+    }
+
+    /**
+     * SWM: created to replicate the C# / operator.
+     * FIXME: keys missing from the divisor are divided by 1.0, i.e. left unchanged; it is not
+     * clear that this is the intended behavior.
+     * @param value the element-wise divisor
+     * @return a new vector with this vector's elements divided by the matching elements of value
+     */
+    public SparseDoubleVector<TKey> divide(SparseDoubleVector<TKey> value) {
+        SparseDoubleVector<TKey> ret = new SparseDoubleVector<>();
+
+        for (TKey k : this.keySet()) {
+            double denom = 1.0;
+            if (value.containsKey(k)) {
+                denom = value.get(k);
+            }
+            ret.put(k, this.get(k) / denom);
+        }
+
+        return ret;
+    }
+
+    /**
+     * SWM: created to replicate the C# * operator.
+     * @param value the scalar multiplier
+     * @return a new vector with every element multiplied by value
+     */
+    public SparseDoubleVector<TKey> multiply(double value) {
+        SparseDoubleVector<TKey> ret = new SparseDoubleVector<>();
+
+        for (TKey t : this.keySet()) {
+            ret.put(t, this.get(t) * value);
+        }
+
+        return ret;
+    }
+
+    public SparseDoubleVector<TKey> Abs() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public void put(SparseDoubleVector<TKey> values) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    /**
+     * Copies all elements of another vector into this vector, multiplied by a coefficient.
+     * Note that existing keys are overwritten, not summed.
+     * @param coefficient multiplier applied to each value
+     * @param values the vector to copy from
+     */
+    public void put(double coefficient, SparseDoubleVector<TKey> values) {
+        for (TKey k : values.keySet()) {
+            this.put(k, coefficient * values.get(k));
+        }
+    }
+
+    public SparseDoubleVector<TKey> Ceiling() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public void Clear() {
+        this.clear();
+    }
+
+    public boolean ContainsKey(TKey key) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    /**
+     * Takes the exponential of each element in the vector.
+     * @return a new vector with exp applied element-wise
+     */
+    public SparseDoubleVector<TKey> Exp() {
+        SparseDoubleVector<TKey> ret = new SparseDoubleVector<>();
+
+        for (TKey k : this.keySet()) {
+            ret.put(k, Math.exp(this.get(k)));
+        }
+        return ret;
+    }
+
+    // Probably won't implement this.
+    // public SparseDoubleVector<TKey> Filter(Predicate<Double> inclusionPredicate){}
+
+    public SparseDoubleVector<TKey> Floor() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public SparseDoubleVector<TKey> Log() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public SparseDoubleVector<TKey> Log(double newBase) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public SparseDoubleVector<TKey> Max(double maximum) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public SparseDoubleVector<TKey> Max(SparseDoubleVector<TKey> otherVector) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public double MaxElement() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public SparseDoubleVector<TKey> Min(double minimum) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public SparseDoubleVector<TKey> Min(SparseDoubleVector<TKey> otherVector) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public double MinElement() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    /**
+     * Returns the p-norm of this vector: (sum over elements x of x^p)^(1/p).
+     * @param p the order of the norm
+     * @return the p-norm
+     */
+    public double PNorm(double p) {
+        double ret = 0;
+        for (TKey t : this.keySet()) {
+            ret += Math.pow(this.get(t), p);
+        }
+
+        return Math.pow(ret, 1 / p);
+    }
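+
+    // A worked sketch of PNorm (hypothetical values): for p = 2 and values {3.0, 4.0},
+    // ret = 9 + 16 = 25 and the result is 25^(1/2) = 5. Note that Math.pow returns NaN
+    // for negative bases with fractional exponents, so for non-integer p the entries
+    // are assumed to be non-negative.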
+
+    public SparseDoubleVector<TKey> Pow(double exponent) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public boolean Remove(TKey key) {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public void RemoveRedundantElements() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+    public SparseDoubleVector<TKey> Sign() {
+        throw new NotImplementedException("not yet implemented...");
+    }
+
+}
diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/TopList.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/TopList.java
new file mode 100644
index 000000000..f12ba3408
--- /dev/null
+++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/TopList.java
@@ -0,0 +1,151 @@
+/**
+ * This software is released under the University of Illinois/Research and Academic Use License. See
+ * the LICENSE file in the root folder for details. Copyright (c) 2016
+ *
+ * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign
+ * http://cogcomp.cs.illinois.edu/
+ */
+package edu.illinois.cs.cogcomp.utils;
+
+import edu.illinois.cs.cogcomp.core.datastructures.Pair;
+
+import java.util.*;
+
+/**
+ * A bounded list that keeps at most the top-K pairs, sorted descending by key.
+ * Created by stephen on 9/24/15.
+ */
+public class TopList<TKey extends Comparable<TKey>, TValue> implements Iterable<Pair<TKey, TValue>> {
+
+    private LinkedList<Pair<TKey, TValue>> ilist;
+    private int topK;
+
+    public TopList(int topK) {
+        ilist = new LinkedList<>();
+        this.topK = topK;
+    }
+
+    public int size() {
+        return ilist.size();
+    }
+
+    @Override
+    public String toString() {
+        String lst = "";
+        for (Pair<TKey, TValue> p : ilist) {
+            lst += "@@" + p.getFirst() + " :: " + p.getSecond() + "@@";
+            lst += ", ";
+        }
+
+        return "TopList{" +
+                "ilist=" + lst +
+                '}';
+    }
+
+    /**
+     * Inserts the pair so the list stays sorted descending by key; if the list then exceeds topK
+     * elements, the pair with the smallest key is dropped.
+     * @param other the pair to insert
+     */
+    public void add(Pair<TKey, TValue> other) {
+        if (ilist.size() == 0) {
+            this.ilist.add(other);
+            return;
+        }
+
+        int addat = -1;
+
+        for (int i = 0; i < ilist.size(); i++) {
+            Pair<TKey, TValue> mine = ilist.get(i);
+
+            if (other.getFirst().compareTo(mine.getFirst()) > 0) {
+                addat = i;
+                break;
+            }
+        }
+
+        // insert it
+        if (addat != -1) {
+            this.ilist.add(addat, other);
+        } else { // addat == -1: put it at the end
+            this.ilist.addLast(other);
+        }
+
+        if (this.ilist.size() > topK) {
+            this.ilist.pollLast();
+        }
+    }
+
+    public void add(TKey t, TValue p) {
+        Pair<TKey, TValue> newp = new Pair<>(t, p);
+        this.add(newp);
+    }
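+
+    // A minimal usage sketch (hypothetical values) of the top-K behavior:
+    //
+    //   TopList<Double, String> top = new TopList<>(2);
+    //   top.add(0.3, "a");
+    //   top.add(0.9, "b");
+    //   top.add(0.5, "c");   // evicts (0.3, "a"); list is [(0.9, "b"), (0.5, "c")]
+    //   top.getFirst();      // (0.9, "b"), the pair with the highest key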
+
+    public int indexOf(TValue v) {
+        int i = 0;
+        for (Pair<TKey, TValue> p : this.ilist) {
+            if (v.equals(p.getSecond())) {
+                return i;
+            }
+            i++;
+        }
+
+        return -1;
+    }
+
+    /**
+     * Convert this TopList into a plain list.
+     * @return the pairs in descending key order
+     */
+    public List<Pair<TKey, TValue>> toList() {
+        List<Pair<TKey, TValue>> out = new ArrayList<>(ilist.size());
+        for (Pair<TKey, TValue> p : ilist) {
+            out.add(p);
+        }
+        return out;
+    }
+
+    /**
+     * Get the first pair in the TopList: the one with the highest key.
+     * @return the pair with the highest key
+     */
+    public Pair<TKey, TValue> getFirst() {
+        return ilist.getFirst();
+    }
+
+    @Override
+    public Iterator<Pair<TKey, TValue>> iterator() {
+        return this.ilist.iterator();
+    }
+
+    // Unfinished ports of the C# key/value wrappers.
+    // C#: KeyList : IList<TKey>, ICollection, IEnumerable<TKey>, IEnumerable
+    public class KeyList {
+        // What goes here?
+    }
+
+    // C#: ValueList : IList<TValue>, ICollection, IEnumerable<TValue>, IEnumerable
+    public class ValueList extends AbstractList<Object> {
+
+        @Override
+        public Object get(int index) {
+            return null;
+        }
+
+        @Override
+        public Iterator<Object> iterator() {
+            return null;
+        }
+
+        @Override
+        public int size() {
+            return 0;
+        }
+    }
+
+}
diff --git a/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/Utils.java b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/Utils.java
new file mode 100644
index 000000000..73bec8c17
--- /dev/null
+++ b/transliteration/src/main/java/edu/illinois/cs/cogcomp/utils/Utils.java
@@ -0,0 +1,526 @@
+/**
+ * This software is released under the University of Illinois/Research and Academic Use License. See
+ * the LICENSE file in the root folder for details. Copyright (c) 2016
+ *
+ * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign
+ * http://cogcomp.cs.illinois.edu/
+ */
+package edu.illinois.cs.cogcomp.utils;
+
+import com.ibm.icu.text.Replaceable;
+import com.ibm.icu.text.Transliterator;
+import edu.illinois.cs.cogcomp.core.algorithms.LevensteinDistance;
+import edu.illinois.cs.cogcomp.core.io.LineIO;
+import edu.illinois.cs.cogcomp.core.utilities.ArrayUtilities;
+import edu.illinois.cs.cogcomp.transliteration.Example;
+import edu.illinois.cs.cogcomp.transliteration.MultiExample;
+import edu.illinois.cs.cogcomp.transliteration.SPModel;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.w3c.dom.Document;
+import org.w3c.dom.Node;
+import org.w3c.dom.NodeList;
+import org.xml.sax.SAXException;
+
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
+import javax.xml.parsers.ParserConfigurationException;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Data-reading and evaluation helpers for the transliteration module.
+ * Created by mayhew2 on 11/17/15.
+ */
+public class Utils {
+
+    private static Logger logger = LoggerFactory.getLogger(Utils.class);
+
+    /**
+     * This reads a file in the SRILM ngram format
+     * http://www.speech.sri.com/projects/srilm/manpages/ngram-format.5.html
+     * and populates the language model data structure.
+     * @param fname path to the ngram-format file
+     * @return a map from ngram (with spaces removed) to probability
+     * @throws FileNotFoundException
+     */
+    public static HashMap<String, Double> readSRILM(String fname) throws FileNotFoundException {
+        List<String> lines = LineIO.read(fname);
+
+        HashMap<String, Double> out = new HashMap<>();
+
+        for (String line : lines) {
+            if (line.trim().length() == 0 || line.startsWith("\\") || line.contains("ngram")) {
+                continue;
+            }
+
+            String[] sline = line.trim().split("\t");
+            // important because of the log probabilities: the ngram format stores log10
+            // probabilities, so convert back with pow(10, x).
+            Double v = Math.pow(10, Double.parseDouble(sline[0]));
+            String ngram = sline[1];
+
+            String[] chars = ngram.split(" ");
+
+            out.put(StringUtils.join(chars, ""), v);
+        }
+
+        return out;
+    }
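+
+    // Sketch of the expected input (a hypothetical ngram-format line, tab-separated):
+    //
+    //   "-0.30103\ta b"
+    //
+    // yields out.put("ab", 0.5), since 10^-0.30103 is approximately 0.5.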
+
+    /**
+     * This measures the WAVE score of a set of productions. The WAVE score comes from
+     * (Kumaran et al., 2010) and is a measure of transliterability.
+     * @param fname the file name of a set of learned productions
+     * @return the WAVE score
+     */
+    public static double WAVE(String fname) throws FileNotFoundException {
+        List<String> lines = LineIO.read(fname);
+
+        HashMap<String, Integer> srcFreq = new HashMap<>();
+        HashMap<String, Integer> tgtFreq = new HashMap<>();
+
+        HashMap<String, Double> entropy = new HashMap<>();
+
+        for (String line : lines) {
+            if (line.trim().length() == 0 || line.startsWith("#")) {
+                continue;
+            }
+
+            String[] sline = line.split("\t");
+
+            String src = sline[0];
+            String tgt = sline[1];
+            double prob = Double.parseDouble(sline[2]);
+
+            Dictionaries.IncrementOrSet(srcFreq, src, 1, 1);
+
+            Dictionaries.IncrementOrSet(tgtFreq, tgt, 1, 1);
+
+            double v = prob * Math.log(prob);
+            Dictionaries.IncrementOrSet(entropy, src, v, v);
+        }
+
+        double total = 0;
+        for (int v : srcFreq.values()) {
+            total += v;
+        }
+
+        double WAVE = 0;
+
+        for (String i : srcFreq.keySet()) {
+            // subtract because the entropy terms were accumulated as p*log(p), without the usual negation.
+            WAVE -= srcFreq.get(i) / total * entropy.get(i);
+        }
+
+        return WAVE;
+    }
+
+    /**
+     * This is the fuzzy F1 measure used by the NEWS 2015 shared task: the prediction is compared
+     * against the referent with the smallest edit distance.
+     * @param prediction the predicted transliteration
+     * @param referents the set of acceptable gold transliterations
+     * @return the fuzzy F1 score
+     */
+    public static double GetFuzzyF1(String prediction, List<String> referents) {
+        // calculate Fuzzy F1
+        String cand = prediction;
+        double bestld = Double.MAX_VALUE;
+        String bestref = "";
+        for (String reference : referents) {
+            double ld = LevensteinDistance.getLevensteinDistance(reference, cand);
+            if (ld < bestld) {
+                bestref = reference;
+                bestld = ld;
+            }
+        }
+
+        // approximate the longest-common-subsequence length from the edit distance.
+        double lcs = (cand.length() + bestref.length() - bestld) * 0.5;
+        double R = lcs / bestref.length();
+        double P = lcs / cand.length();
+        double F1 = 2 * R * P / (R + P);
+        return F1;
+    }
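+
+    // A worked sketch (hypothetical values): prediction "jon", closest referent "john".
+    // The edit distance is 1, so lcs = (3 + 4 - 1) / 2 = 3, R = 3/4, P = 3/3 = 1,
+    // and F1 = 2 * R * P / (R + P) = 6/7, approximately 0.857.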
+
+    /**
+     * Helper method; this always rearranges the names according to edit distance.
+     * @param file name of file
+     * @return list of examples
+     */
+    public static List<Example> readWikiData(String file) throws FileNotFoundException {
+        return readWikiData(file, true);
+    }
+
+    /**
+     * This reads data in the format created by the wikipedia-api project, commonly named
+     * wikidata.Language.
+     * @param file name of file
+     * @param fix whether or not the names should be reordered according to edit distance
+     * @return list of examples
+     * @throws FileNotFoundException
+     */
+    public static List<Example> readWikiData(String file, boolean fix) throws FileNotFoundException {
+        List<Example> examples = new ArrayList<>();
+        List<String> lines = LineIO.read(file);
+
+        String id = "Any-Latin; NFD; [^\\p{Alnum}] Remove";
+        //id = "Any-Latin; NFD";
+        Transliterator t = Transliterator.getInstance(id);
+
+        HashSet<Example> unique = new HashSet<>();
+
+        int skipping = 0;
+
+        for (String line : lines) {
+            if (line.contains("#")) {
+                continue;
+            }
+
+            String[] parts = line.split("\t");
+
+            if (parts.length < 2) {
+                continue;
+            }
+
+            // In wikipedia data, the foreign name comes first, English second.
+            String foreign = parts[0].toLowerCase();
+            String english = parts[1].toLowerCase();
+            String[] ftoks = foreign.split(" ");
+            String[] etoks = english.split(" ");
+
+            if (ftoks.length != etoks.length) {
+                logger.error("Mismatching length of tokens: " + english);
+                skipping++;
+                continue;
+            }
+
+            // other heuristics to help clean the data
+            if (english.contains("jr.") || english.contains("sr.") ||
+                    english.contains(" of ") || english.contains(" de ") ||
+                    english.contains("(") || english.contains("pope ")) {
+                skipping++;
+                //logger.debug("Skipping: " + english);
+                continue;
+            }
+
+            int numtoks = ftoks.length;
+
+            for (int i = 0; i < numtoks; i++) {
+                String ftrans = t.transform(ftoks[i]);
+
+                int mindist = Integer.MAX_VALUE;
+                String bestmatch = null;
+
+                // this is intended to help with ordering.
+                for (int j = 0; j < numtoks; j++) {
+                    int d = LevensteinDistance.getLevensteinDistance(ftrans, etoks[j]);
+                    if (d < mindist) {
+                        // match etoks[j] with ftrans
+                        bestmatch = etoks[j];
+                        mindist = d;
+                        // (etoks[j] is not actually taken out of the running.)
+                    }
+                }
+
+                // strip those pesky commas.
+                if (ftoks[i].endsWith(",")) {
+                    ftoks[i] = ftoks[i].substring(0, ftoks[i].length() - 1);
+                }
+
+                // This version uses transliterated words as the target (cheating):
+                //examples.add(new Example(bestmatch, ftrans));
+
+                Example addme;
+                if (fix) {
+                    // This uses the best aligned version (recommended).
+                    addme = new Example(bestmatch, ftoks[i]);
+                } else {
+                    // This assumes the file ordering is correct.
+                    addme = new Example(etoks[i], ftoks[i]);
+                }
+                examples.add(addme);
+                unique.add(addme);
+            }
+        }
+        //System.out.println(file.split("\\.")[1] + " & " + numnames + " & " + examples.size() + " & " + unique.size() + " \\\\");
+        logger.debug(String.format("Skipped %d lines", skipping));
+        return new ArrayList<>(unique);
+    }
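+
+    // Sketch of the expected wikidata input (a hypothetical tab-separated line,
+    // foreign name first, English second):
+    //
+    //   владимир путин	vladimir putin
+    //
+    // Each token pair becomes one Example after the transliteration-based alignment above.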
+
+    /**
+     * This reads data from the Anne Irvine, CCB paper called Transliterating from Any Language.
+     * @param srccode the source language code
+     * @param targetcode the target language code
+     * @return list of examples
+     */
+    public static List<Example> readCCBData(String srccode, String targetcode) throws FileNotFoundException {
+        List<Example> examples = new ArrayList<>();
+
+        String fname = "/shared/corpora/transliteration/from_anne_irvine/wikipedia_names";
+        List<String> lines = LineIO.read(fname);
+
+        List<String> key = Arrays.asList(lines.get(0).split("\t"));
+        int srcind = key.indexOf(srccode);
+        int tgtind = key.indexOf(targetcode);
+
+        System.out.println(srcind + ", " + tgtind);
+
+        int i = 0;
+        for (String line : lines) {
+            if (i == 0 || line.trim().length() == 0) {
+                i++;
+                continue;
+            }
+
+            String[] sline = line.split("\t");
+
+            // split() drops trailing empty fields, so a blank target column shortens the array.
+            if (tgtind >= sline.length) {
+                i++;
+                continue;
+            }
+
+            String src = sline[srcind].trim();
+            String tgt = sline[tgtind].trim();
+
+            if (src.length() > 0 && tgt.length() > 0) {
+                Example e = new Example(src, tgt);
+                examples.add(e);
+            }
+
+            i++;
+        }
+
+        return examples;
+    }
+
+    /**
+     * Flattens a list of MultiExamples into single-word Examples, pairing up source and target
+     * tokens positionally.
+     * @param lme the MultiExamples to convert
+     * @return the flattened list of Examples
+     */
+    public static List<Example> convertMulti(List<MultiExample> lme) {
+        List<Example> training = new ArrayList<>();
+        for (MultiExample me : lme) {
+            for (Example e : me.toExampleList()) {
+                String[] tls = e.getTransliteratedWord().split(" ");
+                String[] ss = e.sourceWord.split(" ");
+
+                if (tls.length != ss.length) {
+                    logger.error("Mismatched length: " + e.sourceWord);
+                    continue;
+                }
+
+                for (int i = 0; i < tls.length; i++) {
+                    training.add(new Example(ss[i], tls[i]));
+                }
+            }
+        }
+        return training;
+    }
+
+    /**
+     * Used for reading data from the NEWS 2015 dataset.
+     * @param fname path to the XML file
+     * @return list of MultiExamples
+     * @throws ParserConfigurationException
+     * @throws IOException
+     * @throws SAXException
+     */
+    public static List<MultiExample> readNEWSData(String fname) throws ParserConfigurationException, IOException, SAXException {
+        File file = new File(fname);
+        DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
+        DocumentBuilder db = dbf.newDocumentBuilder();
+        Document document = db.parse(file);
+
+        NodeList nl = document.getElementsByTagName("Name");
+
+        List<MultiExample> examples = new ArrayList<>();
+
+        for (int i = 0; i < nl.getLength(); i++) {
+            Node n = nl.item(i);
+
+            NodeList sourceandtargets = n.getChildNodes();
+            MultiExample me = null;
+            for (int j = 0; j < sourceandtargets.getLength(); j++) {
+                Node st = sourceandtargets.item(j);
+                if (st.getNodeName().equals("SourceName")) {
+                    me = new MultiExample(st.getTextContent().toLowerCase(), new ArrayList<String>());
+                } else if (st.getNodeName().equals("TargetName")) {
+                    if (me != null) {
+                        me.addTransliteratedWord(st.getTextContent());
+                    }
+                }
+            }
+            examples.add(me);
+        }
+
+        return examples;
+    }
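+
+    // Sketch of the XML shape this parser expects, inferred from the element names used
+    // above (hypothetical content):
+    //
+    //   <Name>
+    //     <SourceName>stephen</SourceName>
+    //     <TargetName>stiven</TargetName>
+    //   </Name>
+    //
+    // Each <Name> element becomes one MultiExample holding all of its TargetNames.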
+
+    public static void romanization() throws FileNotFoundException {
+        List<String> lines = LineIO.read("/shared/corpora/transliteration/wikidata/wikidata.Russian.fixed");
+
+        String id = "Any-Arabic; NFD; [^\\p{Alnum}] Remove";
+        //id = "Any-Latin; NFD";
+        Transliterator t = Transliterator.getInstance(id);
+
+        int jj = 0;
+
+        List<Example> examples = new ArrayList<>();
+
+        for (String line : lines) {
+            if (line.contains("#")) {
+                continue;
+            }
+
+            jj++;
+            String[] parts = line.split("\t");
+
+            if (parts.length < 2) {
+                continue;
+            }
+
+            // In wikipedia data, the foreign name comes first, English second.
+            String foreign = parts[0].toLowerCase();
+            String english = parts[1].toLowerCase();
+            String[] ftoks = foreign.split(" ");
+            String[] etoks = english.split(" ");
+
+            if (ftoks.length != etoks.length) {
+                logger.error("Mismatching length of tokens: " + english);
+                continue;
+            }
+
+            int numtoks = ftoks.length;
+
+            for (int i = 0; i < numtoks; i++) {
+                String ftrans = t.transform(ftoks[i]);
+                ftoks[i] = ftrans;
+
+                int mindist = Integer.MAX_VALUE;
+                String bestmatch = null;
+
+                for (int j = 0; j < numtoks; j++) {
+                    int d = LevensteinDistance.getLevensteinDistance(ftrans, etoks[j]);
+                    if (d < mindist) {
+                        // match etoks[j] with ftrans
+                        bestmatch = etoks[j];
+                        mindist = d;
+                    }
+                }
+
+                //System.out.print(ftrans + " : " + bestmatch + ", ");
+                examples.add(new Example(bestmatch, ftrans));
+            }
+
+            if (jj % 1000 == 0) {
+                System.out.println(jj);
+            }
+        }
+
+        System.out.println(examples.size());
+
+        Enumeration<String> tids = t.getAvailableIDs();
+        while (tids.hasMoreElements()) {
+            String e = tids.nextElement();
+            //System.out.println(e);
+        }
+    }
+
+    public static void getSize(String langname) throws FileNotFoundException {
+        String wikidata = "/shared/corpora/transliteration/wikidata/wikidata.";
+        List<Example> e = readWikiData(wikidata + langname);
+    }
+
+    public static void main(String[] args) throws Exception {
+        //romanization();
+
+        String[] arabic_names = {"Urdu", "Arabic", "Egyptian_Arabic", "Mazandarani", "Pashto", "Persian", "Western_Punjabi"};
+        String[] devanagari_names = {"Newar", "Hindi", "Marathi", "Nepali", "Sanskrit"};
+        String[] cyrillic_names = {"Chuvash", "Bashkir", "Bulgarian", "Chechen", "Kirghiz", "Macedonian", "Russian", "Ukrainian"};
+
+        //for(String name : arabic_names){
+        //    System.out.println(name + " : " + WAVE("models/probs-" + name + "-Urdu.txt"));
+        //    getSize(name);
+        //}
+
+        String lang = "Arabic";
+        String wikidata = "Data/wikidata." + lang;
+
+        List<String> allnames = LineIO.read("/Users/stephen/Dropbox/papers/NAACL2016/data/all-names2.txt");
+
+        List<Example> training = readWikiData(wikidata);
+
+        training = training.subList(0, 2000);
+
+        SPModel m = new SPModel(training);
+        m.Train(5);
+
+        TopList<Double, String> res = m.Generate("stephen");
+        System.out.println(res);
+
+        List<String> outlines = new ArrayList<>();
+
+        int i = 0;
+        for (String nameAndLabel : allnames) {
+            if (i % 100 == 0) {
+                System.out.println(i);
+            }
+            i++;
+
+            String[] s = nameAndLabel.split("\t");
+            String name = s[0];
+            String label = s[1];
+
+            String[] sname = name.split(" ");
+
+            String line = "";
+            for (String tok : sname) {
+                res = m.Generate(tok.toLowerCase());
+                if (res.size() > 0) {
+                    String topcand = res.getFirst().getSecond();
+                    line += topcand + " ";
+                }
+            }
+
+            if (line.trim().length() > 0) {
+                outlines.add(line.trim() + "\t" + label);
+            }
+        }
+
+        LineIO.write("/Users/stephen/Dropbox/papers/NAACL2016/data/all-names-" + lang + "2.txt", outlines);
+
+//        Transliterator t = Transliterator.getInstance("Any-am_FONIPA");
+//
+//        String result = t.transform("Stephen");
+//        System.out.println(result);
+//
+//        Enumeration<String> tids = t.getAvailableIDs();
+//
+//        while(tids.hasMoreElements()){
+//            String e = tids.nextElement();
+//            System.out.println(e);
+//        }
+    }
+
+}
diff --git a/transliteration/src/main/resources/log4j.properties b/transliteration/src/main/resources/log4j.properties
new file mode 100644
index 000000000..4a9f9365b
--- /dev/null
+++ b/transliteration/src/main/resources/log4j.properties
@@ -0,0 +1,9 @@
+# suppress inspection "UnusedProperty" for whole file
+# Root logger option
+log4j.rootLogger=DEBUG, stdout
+
+# Direct log messages to stdout
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.Target=System.out
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss} %-5p %c{1}:%L - %m%n
\ No newline at end of file
diff --git a/transliteration/src/test/java/edu/illinois/cs/cogcomp/transliteration/TestTransliteration.java b/transliteration/src/test/java/edu/illinois/cs/cogcomp/transliteration/TestTransliteration.java
new file mode 100644
index 000000000..edb5e7063
--- /dev/null
+++ b/transliteration/src/test/java/edu/illinois/cs/cogcomp/transliteration/TestTransliteration.java
@@ -0,0 +1,97 @@
+/**
+ * This software is released under the University of Illinois/Research and Academic Use License. See
+ * the LICENSE file in the root folder for details. Copyright (c) 2016
+ *
+ * Developed by: The Cognitive Computation Group University of Illinois at Urbana-Champaign
+ * http://cogcomp.cs.illinois.edu/
+ */
+package edu.illinois.cs.cogcomp.transliteration;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Created by mayhew2 on 11/5/15.
+ */
+public class TestTransliteration {
+
+    /**
+     * This is just a smoke test to make sure that we can load and run everything.
+     */
+    @Test
+    public void testModelLoad() {
+        List<Example> examples = new ArrayList<>();
+        examples.add(new Example("this", "this"));
+
+        SPModel model = new SPModel(examples);
+
+        boolean rom = false;
+        model.Train(1, rom, examples);
+
+        System.out.println(model.Probability("this", "this"));
+    }
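+
+    // A minimal sketch of the generation API as exercised in Utils.main (hypothetical output):
+    //
+    //   TopList<Double, String> res = model.Generate("this");
+    //   res.getFirst().getSecond();  // the best-scoring transliteration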
+
+    /**
+     * Test to ensure that Example equality is working correctly.
+     */
+    @Test
+    public void testExampleEquality() {
+        HashSet<Example> ee = new HashSet<>();
+
+        Example e = new Example("John", "Smith");
+        Example e2 = new Example("John", "Smith");
+
+        ee.add(e);
+        ee.add(e2);
+
+        // use JUnit assertions: the assert keyword is a no-op unless -ea is enabled.
+        assertEquals(1, ee.size());
+
+        Example e3 = new Example("Smith", "John");
+        ee.add(e3);
+
+        assertEquals(2, ee.size());
+    }
+
+    /**
+     * Test to ensure that MultiExample equality is working correctly.
+     */
+    @Test
+    public void testMultiExampleEquality() {
+        HashSet<MultiExample> ee = new HashSet<>();
+
+        List<String> l = new ArrayList<>();
+        l.add("wut");
+        l.add("why");
+        MultiExample me = new MultiExample("John", l);
+
+        List<String> l2 = new ArrayList<>();
+        l2.add("wut");
+        l2.add("why");
+        MultiExample me2 = new MultiExample("John", l2);
+
+        ee.add(me);
+        ee.add(me2);
+
+        assertEquals(1, ee.size());
+
+        List<String> l3 = new ArrayList<>();
+        // these are in the opposite order.
+        l3.add("why");
+        l3.add("wut");
+        MultiExample me3 = new MultiExample("John", l3);
+        ee.add(me3);
+
+        assertEquals(2, ee.size());
+    }
+
+}
diff --git a/transliteration/src/test/java/edu/illinois/cs/cogcomp/transliteration/TransliterationAnnotatorTest.java b/transliteration/src/test/java/edu/illinois/cs/cogcomp/transliteration/TransliterationAnnotatorTest.java
new file mode 100644
index 000000000..68572b0de
--- /dev/null
+++ b/transliteration/src/test/java/edu/illinois/cs/cogcomp/transliteration/TransliterationAnnotatorTest.java
@@ -0,0 +1,28 @@
+package edu.illinois.cs.cogcomp.transliteration;
+
+import edu.illinois.cs.cogcomp.annotation.AnnotatorException;
+import edu.illinois.cs.cogcomp.annotation.TransliterationAnnotator;
+import edu.illinois.cs.cogcomp.core.datastructures.ViewNames;
+import edu.illinois.cs.cogcomp.core.datastructures.textannotation.TextAnnotation;
+import edu.illinois.cs.cogcomp.core.utilities.DummyTextAnnotationGenerator;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+public class TransliterationAnnotatorTest {
+
+    TransliterationAnnotator annotator = null;
+
+    @Before
+    public void setUp() throws Exception {
+        this.annotator = new TransliterationAnnotator();
+    }
+
+    @Test
+    public void testCleanText() throws AnnotatorException {
+        TextAnnotation ta = DummyTextAnnotationGenerator.generateAnnotatedTextAnnotation(false, 3);
+        annotator.getView(ta);
+        assertTrue(ta.hasView(ViewNames.TRANSLITERATION));
+        System.out.println(ta.getView(ViewNames.TRANSLITERATION));
+    }
+}