pair : examples)
+ fExamples.add(new Triple<>(pair.getFirst(), pair.getSecond(), 1.0));
+
+ return fExamples;
+ }
+
+ /**
+ * Calculates a probability table for P(String2 | String1).
+ * This normalizes each production's count by the total count of its source string.
+ *
+ * FIXME: This is identical to... WikiTransliteration.NormalizeBySourceSubstring
+ *
+ * @param counts a map from productions to (unnormalized) counts
+ * @return a map from productions to P(second | first)
+ */
+ public static HashMap<Production, Double> PSecondGivenFirst(HashMap<Production, Double> counts) {
+ // counts of first words in productions.
+ HashMap<String, Double> totals1 = WikiTransliteration.GetAlignmentTotals1(counts);
+
+ HashMap<Production, Double> result = new HashMap<>(counts.size());
+ for (Production prod : counts.keySet()) // loop over all productions
+ {
+ double prodcount = counts.get(prod);
+ double sourcecounts = totals1.get(prod.getFirst()); // be careful of unboxing!
+ double value = sourcecounts == 0 ? 0 : (prodcount / sourcecounts);
+ result.put(prod, value);
+ }
+
+ return result;
+ }
+
+
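+ /**
+ * Normalizes a map of production scores so that the values sum to 1. Assumes the sum is nonzero.
+ */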
+ private static HashMap<Production, Double> SumNormalize(HashMap<Production, Double> vector) {
+ HashMap<Production, Double> result = new HashMap<>(vector.size());
+ double sum = 0;
+ for (double value : vector.values()) {
+ sum += value;
+ }
+
+ for (Production key : vector.keySet()) {
+ Double value = vector.get(key);
+ result.put(key, value / sum);
+ }
+
+ return result;
+ }
+
+
+ /**
+ * Does this need access to the ngram size? No: it gathers ngrams of all lengths up to maxSubstringLength so the model can back off.
+ *
+ * By default, this includes padding.
+ *
+ * @param examples
+ * @param maxSubstringLength
+ * @return
+ */
+ public static HashMap<String, Double> GetNgramCounts(List<String> examples, int maxSubstringLength) {
+ return WikiTransliteration.GetNgramCounts(1, maxSubstringLength, examples, true);
+ }
+
+ /**
+ * Given a wikidata file, this gets all the words in the foreign language for the language model.
+ *
+ * @param fname
+ * @return
+ */
+ public static List<String> getForeignWords(String fname) throws FileNotFoundException {
+ List<String> lines = LineIO.read(fname);
+ List<String> words = new ArrayList<>();
+
+ for (String line : lines) {
+ String[] parts = line.trim().split("\t");
+ String foreign = parts[0];
+
+ String[] fsplit = foreign.split(" ");
+ for (String word : fsplit) {
+ words.add(word);
+ }
+ }
+
+ return words;
+ }
+
+ /**
+ * Given a list of training examples, this gets all the transliterated (foreign) words for the language model.
+ *
+ * @param examples the training examples to collect foreign words from
+ * @return the list of foreign words
+ */
+ public static List<String> getForeignWords(List<Example> examples) throws FileNotFoundException {
+
+ List<String> words = new ArrayList<>();
+
+ for (Example e : examples) {
+ words.add(e.getTransliteratedWord());
+ }
+
+ return words;
+ }
+
+
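+ /**
+ * Reads a tab-delimited file of word pairs, lowercasing both fields; lines that do not have exactly two fields are skipped.
+ */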
+ public static List<Pair<String, String>> GetTabDelimitedPairs(String filename) throws FileNotFoundException {
+ List<Pair<String, String>> result = new ArrayList<>();
+
+ for (String line : LineIO.read(filename)) {
+ String[] pair = line.trim().split("\t");
+ if (pair.length != 2) continue;
+ result.add(new Pair<>(pair[0].trim().toLowerCase(), pair[1].trim().toLowerCase()));
+ }
+
+ return result;
+ }
+
+
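+ // Cache of per-pair alignment counts used by WeightingMode.MaxAlignment, keyed by the (source, target) production.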
+ static HashMap<Production, HashMap<Production, Double>> maxCache = new HashMap<>();
+
+ /**
+ * This returns a map from productions to counts, aggregated over the entire training corpus. It contains every
+ * production seen in the training data; productions that never occur in training will not appear here.
+ *
+ * The normalization parameter decides whether the counts are normalized (typically they are not).
+ *
+ * @param maxSubstringLength1
+ * @param maxSubstringLength2
+ * @param examples
+ * @param probs
+ * @param weightingMode
+ * @param normalization
+ * @param getExampleCounts
+ * @return
+ */
+ static HashMap<Production, Double> MakeRawAlignmentTable(int maxSubstringLength1, int maxSubstringLength2, List<Triple<String, String, Double>> examples, HashMap<Production, Double> probs, WeightingMode weightingMode, WikiTransliteration.NormalizationMode normalization, boolean getExampleCounts) {
+ InternDictionary<String> internTable = new InternDictionary<>();
+ HashMap<Production, Double> counts = new HashMap<>();
+
+ List<List<Pair<Production, Double>>> exampleCounts = (getExampleCounts ? new ArrayList<List<Pair<Production, Double>>>(examples.size()) : null);
+
+ int alignmentCount = 0;
+ for (Triple<String, String, Double> example : examples) {
+ String sourceWord = example.getFirst();
+ String bestWord = example.getSecond(); // bestWord? Shouldn't it be target word?
+ if (sourceWord.length() * maxSubstringLength2 >= bestWord.length() && bestWord.length() * maxSubstringLength1 >= sourceWord.length()) {
+ alignmentCount++;
+
+ HashMap<Production, Double> wordCounts;
+
+ if (weightingMode == WeightingMode.FindWeighted && probs != null)
+ wordCounts = WikiTransliteration.FindWeightedAlignments(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, probs, internTable, normalization);
+ //wordCounts = WikiTransliteration.FindWeightedAlignmentsAverage(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, probs, internTable, true, normalization);
+ else if (weightingMode == WeightingMode.CountWeighted)
+ wordCounts = WikiTransliteration.CountWeightedAlignments(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, probs, internTable, normalization, false);
+ else if (weightingMode == WeightingMode.MaxAlignment) {
+
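+ // Temporarily subtract this pair's cached counts from probs so its own previous contribution
+ // does not bias the re-alignment; the cached counts are added back after CountMaxAlignments runs.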
+ HashMap<Production, Double> cached = new HashMap<>();
+ Production p = new Production(sourceWord, bestWord);
+ if (maxCache.containsKey(p)) {
+ cached = maxCache.get(p);
+ }
+
+ Dictionaries.AddTo(probs, cached, -1.);
+
+ wordCounts = WikiTransliteration.CountMaxAlignments(sourceWord, bestWord, maxSubstringLength1, probs, internTable, false);
+ maxCache.put(new Production(sourceWord, bestWord), wordCounts);
+
+ Dictionaries.AddTo(probs, cached, 1);
+ } else if (weightingMode == WeightingMode.MaxAlignmentWeighted)
+ wordCounts = WikiTransliteration.CountMaxAlignments(sourceWord, bestWord, maxSubstringLength1, probs, internTable, true);
+ else {//if (weightingMode == WeightingMode.None || weightingMode == WeightingMode.SuperficiallyWeighted)
+ // This executes if probs is null
+ wordCounts = WikiTransliteration.FindAlignments(sourceWord, bestWord, maxSubstringLength1, maxSubstringLength2, internTable, normalization);
+ }
+
+ if (weightingMode == WeightingMode.SuperficiallyWeighted && probs != null) {
+ wordCounts = SumNormalize(Dictionaries.MultiplyDouble(wordCounts, probs));
+ }
+
+ Dictionaries.AddTo(counts, wordCounts, example.getThird());
+
+ if (getExampleCounts) {
+ List<Pair<Production, Double>> curExampleCounts = new ArrayList<>(wordCounts.size());
+ for (Production key : wordCounts.keySet()) {
+ Double value = wordCounts.get(key);
+ curExampleCounts.add(new Pair<>(key, value));
+ }
+
+ exampleCounts.add(curExampleCounts);
+ }
+ } else if (getExampleCounts) {
+ exampleCounts.add(null);
+ }
+ }
+
+ return counts;
+ }
+
+
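+ /**
+ * Normalizes production scores by the total score of their source key, giving P(second | first).
+ */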
+ public static SparseDoubleVector<Pair<Triple<String, String, String>, String>> PSecondGivenFirst(SparseDoubleVector<Pair<Triple<String, String, String>, String>> productionProbs) {
+ SparseDoubleVector<Pair<Triple<String, String, String>, String>> result = new SparseDoubleVector<>();
+ SparseDoubleVector<Triple<String, String, String>> totals = new SparseDoubleVector<>();
+ for (Pair<Triple<String, String, String>, String> key : productionProbs.keySet()) {
+ Double value = productionProbs.get(key);
+ totals.put(key.getFirst(), totals.get(key.getFirst()) + value);
+ }
+
+ for (Pair<Triple<String, String, String>, String> key : productionProbs.keySet()) {
+ Double value = productionProbs.get(key);
+ result.put(key, value / totals.get(key.getFirst()));
+ }
+
+ return result;
+
+ }
+
+
+ /**
+ * Calculates a probability table for P(String2 | String1)
+ *
+ * @param counts
+ * @return
+ */
+ public static SparseDoubleVector<Pair<Triple<String, String, String>, String>> PSecondGivenFirstTriple(SparseDoubleVector<Pair<Triple<String, String, String>, String>> counts) {
+ SparseDoubleVector<Triple<String, String, String>> totals = new SparseDoubleVector<>();
+ for (Pair<Triple<String, String, String>, String> key : counts.keySet()) {
+ Double value = counts.get(key);
+ totals.put(key.getFirst(), totals.get(key.getFirst()) + value);
+ }
+
+ SparseDoubleVector<Pair<Triple<String, String, String>, String>> result = new SparseDoubleVector<>(counts.size());
+ for (Pair<Triple<String, String, String>, String> key : counts.keySet()) {
+ Double value = counts.get(key);
+ double total = totals.get(key.getFirst());
+ result.put(key, total == 0 ? 0 : value / total);
+ }
+
+ return result;
+ }
+
+ /**
+ * Returns a random subset of a list; the list provided is modified to remove the selected items.
+ *
+ * @param wordList
+ * @param count
+ * @param seed
+ * @return
+ */
+ public static List<Pair<String, String>> GetRandomPartOfList(List<Pair<String, String>> wordList, int count, int seed) {
+ Random r = new Random(seed);
+
+ List<Pair<String, String>> randomList = new ArrayList<>(count);
+
+ for (int i = 0; i < count; i++) {
+
+ int index = r.nextInt(wordList.size()); //r.Next(wordList.size());
+ randomList.add(wordList.get(index));
+ wordList.remove(index);
+ }
+
+ return randomList;
+ }
+
+ public static List<String> GetListValues(List<Pair<String, String>> wordList) {
+ List<String> valueList = new ArrayList<>(wordList.size());
+ for (Pair<String, String> pair : wordList)
+ valueList.add(pair.getSecond());
+
+ return valueList;
+ }
+
+
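+ /**
+ * Computes the binomial coefficient C(n, k) ("n choose k") iteratively.
+ */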
+ static double Choose(double n, double k) {
+ double result = 1;
+
+ for (double i = Math.max(k, n - k) + 1; i <= n; ++i)
+ result *= i;
+
+ for (double i = 2; i <= Math.min(k, n - k); ++i)
+ result /= i;
+
+ return result;
+ }
+
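+ /**
+ * Builds a triangular table of binomial coefficients: result[i][j] = Choose(i, j).
+ */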
+ public static double[][] SegmentationCounts(int maxLength) {
+ double[][] result = new double[maxLength][];
+ for (int i = 0; i < maxLength; i++) {
+ result[i] = new double[i + 1];
+ for (int j = 0; j <= i; j++)
+ result[i][j] = Choose(i, j);
+ }
+
+ return result;
+ }
+
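+ /**
+ * For each pair of indices (i, j), result[i][j] = sum over k of Choose(i, k) * Choose(j, k),
+ * i.e. the number of ways to split two sequences with i and j internal split positions into the same number of segments.
+ */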
+ public static double[][] SegSums(int maxLength) {
+ double[][] segmentationCounts = SegmentationCounts(maxLength);
+ double[][] result = new double[maxLength][];
+ for (int i = 0; i < maxLength; i++) {
+ result[i] = new double[maxLength];
+ for (int j = 0; j < maxLength; j++) {
+ int minIJ = Math.min(i, j);
+ for (int k = 0; k <= minIJ; k++)
+ result[i][j] += segmentationCounts[i][k] * segmentationCounts[j][k];// *Math.Pow(0.5, k + 1);
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Number of possible segmentations.
+ */
+ public static double[][] segSums = SegSums(40);
+
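+ /**
+ * Runs EM for transliteration discovery: on each iteration the model is retrained on the training triples
+ * and then evaluated against the candidate words and testing pairs.
+ */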
+ public static void DiscoveryEM(int iterations, List<String> candidateWords, List<Pair<String, String>> trainingPairs, HashMap<String, List<String>> testingPairs, TransliterationModel model) {
+ List<Triple<String, String, Double>> trainingTriples = ConvertExamples(trainingPairs);
+
+ for (int i = 0; i < iterations; i++) {
+ System.out.println("Iteration #" + i);
+
+ long startTime = System.nanoTime();
+ System.out.print("Training...");
+ model = model.LearnModel(trainingTriples);
+ long endTime = System.nanoTime();
+ System.out.println("Finished in " + (startTime - endTime) / (1000000 * 1000) + " seconds.");
+
+ DiscoveryEvaluation(testingPairs, candidateWords, model);
+ }
+
+ System.out.println("Finished.");
+ }
+
+
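+ /**
+ * Replaces final (word-final) forms of Hebrew letters with their regular forms.
+ */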
+ public static String NormalizeHebrew(String word) {
+ word = word.replace('ן', 'נ');
+ word = word.replace('ך', 'כ');
+ word = word.replace('ץ', 'צ');
+ word = word.replace('ם', 'מ');
+ word = word.replace('ף', 'פ');
+
+ return word;
+ }
+
+ public static List<Pair<String, String>> NormalizeHebrew(List<Pair<String, String>> pairs) {
+ List<Pair<String, String>> result = new ArrayList<>(pairs.size());
+
+ for (int i = 0; i < pairs.size(); i++)
+ result.add(new Pair<>(pairs.get(i).getFirst(), NormalizeHebrew(pairs.get(i).getSecond())));
+
+ return result;
+ }
+
+ /**
+ * This is still broken...
+ *
+ * @param path
+ * @return
+ * @throws FileNotFoundException
+ */
+ static HashMap<String, List<String>> GetAlexData(String path) throws FileNotFoundException {
+
+ HashMap<String, List<String>> result = new HashMap<>();
+
+ // TODO: this is all broken.
+// ArrayList<String> data = LineIO.read(path);
+//
+// for (String line : data)
+// {
+// if (line.length() == 0) continue;
+//
+// Match match = Regex.Match(line, "(?<eng>\\w+)\t(?<rroot>\\w+)(?: {(?:-(?<rsuf>\\w*?)(?:(?:, )|}))+)?", RegexOptions.Compiled);
+//
+// String russianRoot = match.Groups["rroot"].Value;
+// if (russianRoot.length() == 0)
+// System.out.println("Parse error");
+//
+// List<String> russianList = new ArrayList<>();
+//
+// //if (match.Groups["rsuf"].Captures.Count == 0)
+// russianList.Add(russianRoot.toLower()); //root only
+// //else
+// for (Capture capture : match.Groups["rsuf"].Captures)
+// russianList.add((russianRoot + capture.Value).ToLower());
+//
+// result[match.Groups["eng"].Value.ToLower()] = russianList;
+//
+// }
+
+ return result;
+ }
+
+
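+ /**
+ * Keeps only the topK highest-scoring productions for each source string and discards the rest.
+ */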
+ public static HashMap<Production, Double> PruneProbs(int topK, HashMap<Production, Double> probs) {
+ HashMap<String, List<Pair<String, Double>>> lists = new HashMap<>();
+ for (Production key : probs.keySet()) {
+ Double value = probs.get(key);
+
+ if (!lists.containsKey(key.getFirst())) {
+ lists.put(key.getFirst(), new ArrayList<Pair<String, Double>>());
+ }
+
+ lists.get(key.getFirst()).add(new Pair<>(key.getSecond(), value));
+ }
+
+ HashMap<Production, Double> result = new HashMap<>();
+ for (String key : lists.keySet()) {
+ List<Pair<String, Double>> value = lists.get(key);
+ Collections.sort(value, new Comparator<Pair<String, Double>>() {
+ @Override
+ public int compare(Pair<String, Double> o1, Pair<String, Double> o2) {
+ double v = o2.getSecond() - o1.getSecond();
+ if (v > 0) {
+ return 1;
+ } else if (v < 0) {
+ return -1;
+ } else {
+ return 0;
+ }
+ }
+ });
+ int toAdd = Math.min(topK, value.size());
+ for (int i = 0; i < toAdd; i++) {
+ result.put(new Production(key, value.get(i).getFirst()), value.get(i).getSecond());
+ }
+ }
+
+ return result;
+ }
+
+ private static void DiscoveryEvaluation(HashMap<String, List<String>> testingPairs, List