-
Notifications
You must be signed in to change notification settings - Fork 0
/
Classifier.java
212 lines (175 loc) · 9.12 KB
/
Classifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
package DNA.classification;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;
/**
 * Scores sets of "optimal" k-mers by how often each k-mer is most frequent in the dataset that its
 * k-mer file is named after.
 *
 * <p>For every {@code *_optimal.txt} file under {@code Data/classification/kmers}, each k-mer is
 * counted (non-overlapping occurrences) in every regular file under
 * {@code Data/classification/originalDatasets}. Each count is normalized to a percentage of the
 * dataset file's character length. The dataset with the highest percentage is the predicted
 * class, and the prediction counts as correct when that dataset's file name equals the k-mer file
 * name with {@code _optimal.txt} removed. Progress and per-file accuracy go to standard output.
 */
public class Classifier {

    /** Folder scanned for k-mer files. */
    private static final String KMERS_FOLDER_PATH = "Data/classification/kmers";
    /** Folder whose regular files are the candidate datasets (classes). */
    private static final String DATASETS_FOLDER_PATH = "Data/classification/originalDatasets";
    /** Suffix that marks a k-mer file; stripping it yields that file's actual class. */
    private static final String KMER_FILE_SUFFIX = "_optimal.txt";

    /**
     * Classifies the k-mers of every k-mer file and prints per-k-mer details and per-file accuracy.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        File kmersFolder = new File(KMERS_FOLDER_PATH);
        File datasetsFolder = new File(DATASETS_FOLDER_PATH);

        // listFiles() returns null when the folder is missing or unreadable. Directories are
        // skipped because they cannot be read as datasets.
        File[] datasetFiles = datasetsFolder.listFiles(file -> file.isFile());
        if (datasetFiles == null) {
            System.err.println("Cannot list dataset folder: " + datasetsFolder.getPath());
            return;
        }
        // listFiles() order is unspecified; sorting makes output and tie-breaking reproducible.
        Arrays.sort(datasetFiles);

        // First, report dataset sizes
        for (File datasetFile : datasetFiles) {
            try {
                String content = readContent(datasetFile);
                System.out.println("Dataset " + datasetFile.getName() +
                        " size: " + content.length() + " bases");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        File[] kmerFiles = kmersFolder.listFiles((dir, name) -> name.endsWith(KMER_FILE_SUFFIX));
        if (kmerFiles == null) {
            System.err.println("Cannot list k-mer folder: " + kmersFolder.getPath());
            return;
        }
        Arrays.sort(kmerFiles);
        for (File kmerFile : kmerFiles) {
            evaluateKmerFile(kmerFile, datasetFiles);
        }
    }

    /**
     * Classifies every k-mer of one k-mer file and prints the resulting accuracy.
     *
     * @param kmerFile     a file whose name ends with {@code _optimal.txt}, one k-mer per line
     * @param datasetFiles the candidate dataset files
     */
    private static void evaluateKmerFile(File kmerFile, File[] datasetFiles) {
        String kmerFileName = kmerFile.getName();
        System.out.println("\nProcessing k-mer file: " + kmerFileName);
        Set<String> kmers = readKmersFromFile(kmerFile);
        // Strip only the trailing suffix; the caller's filter guarantees it is present.
        String actualClass =
                kmerFileName.substring(0, kmerFileName.length() - KMER_FILE_SUFFIX.length());
        System.out.println("Actual class: " + actualClass);

        Map<String, Map<String, Double>> kmerDatasetFrequencies =
                computeFrequencies(kmers, datasetFiles);

        // Classify kmers based on normalized frequencies
        int correctClassifications = 0;
        int totalKmers = kmers.size();
        for (String kmer : kmers) {
            Map<String, Double> datasetFrequencies = kmerDatasetFrequencies.get(kmer);
            String predictedClass = getPredictedClass(datasetFrequencies);
            // NOTE(review): predictedClass is a dataset *file name*. If dataset files carry an
            // extension (e.g. "x.txt") it can never equal actualClass ("x") — confirm naming.
            boolean correct = predictedClass.equals(actualClass);
            if (correct) {
                correctClassifications++;
            }
            // Detailed classification output
            System.out.println("\nKmer: " + kmer);
            System.out.println("Frequencies across datasets:");
            datasetFrequencies.forEach((dataset, freq) ->
                    System.out.printf(" %s: %.6f%%%n", dataset, freq));
            System.out.println("Predicted: " + predictedClass +
                    ", Actual: " + actualClass +
                    ", Correct: " + correct);
        }
        double accuracy = (totalKmers > 0) ?
                (correctClassifications * 100.0 / totalKmers) : 0.0;
        System.out.printf("\nAccuracy for %s: %.2f%% (%d/%d correct)%n",
                kmerFileName, accuracy, correctClassifications, totalKmers);
    }

    /**
     * Computes, for each k-mer, its normalized frequency in every readable dataset.
     *
     * <p>Each dataset is read once per call. The frequency is
     * {@code occurrences * 100 / datasetLength}, or 0 for an empty dataset. Every k-mer gets an
     * entry, possibly empty, even when no dataset could be read.
     *
     * @param kmers        k-mers to count
     * @param datasetFiles dataset files, in the order their frequencies should be recorded
     * @return k-mer → (dataset file name → frequency percentage), inner maps in dataset order
     */
    private static Map<String, Map<String, Double>> computeFrequencies(
            Set<String> kmers, File[] datasetFiles) {
        Map<String, Map<String, Double>> kmerDatasetFrequencies = new HashMap<>();
        for (String kmer : kmers) {
            kmerDatasetFrequencies.put(kmer, new LinkedHashMap<>());
        }
        for (File datasetFile : datasetFiles) {
            String datasetName = datasetFile.getName();
            System.out.println("Counting kmers in dataset: " + datasetName);
            String content;
            try {
                content = readContent(datasetFile);
            } catch (IOException e) {
                e.printStackTrace();
                continue;
            }
            int datasetSize = content.length();
            for (String kmer : kmers) {
                int count = countNonOverlapping(content, kmer);
                // Calculate normalized frequency (percentage); guard against 0/0 = NaN.
                double frequency = datasetSize > 0 ? (count * 100.0) / datasetSize : 0.0;
                kmerDatasetFrequencies.get(kmer).put(datasetName, frequency);
                // Debug output
                System.out.printf(" %s: %d occurrences (%.6f%% of %s)%n",
                        kmer, count, frequency, datasetName);
            }
        }
        return kmerDatasetFrequencies;
    }

    /**
     * Returns the dataset name with the highest frequency.
     *
     * @param datasetFrequencies dataset name → frequency
     * @return the first key (in map iteration order) holding the strictly largest value, or
     *         {@code ""} when the map is empty
     */
    private static String getPredictedClass(Map<String, Double> datasetFrequencies) {
        String predictedClass = "";
        double maxFrequency = -1;
        for (Map.Entry<String, Double> entry : datasetFrequencies.entrySet()) {
            if (entry.getValue() > maxFrequency) {
                maxFrequency = entry.getValue();
                predictedClass = entry.getKey();
            }
        }
        return predictedClass;
    }

    /**
     * Reads k-mers from the given file, one per line, ignoring blank lines and surrounding
     * whitespace.
     *
     * @param kmerFile the file to read (UTF-8)
     * @return the distinct k-mers; whatever was read before an I/O error, which is printed
     */
    private static Set<String> readKmersFromFile(File kmerFile) {
        Set<String> kmers = new HashSet<>();
        try (BufferedReader br = Files.newBufferedReader(kmerFile.toPath(), StandardCharsets.UTF_8)) {
            String line;
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (!line.isEmpty()) {
                    kmers.add(line);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return kmers;
    }

    /**
     * Reads a whole file as a UTF-8 string.
     *
     * @param file the file to read
     * @return the file content
     * @throws IOException if the file cannot be read
     */
    private static String readContent(File file) throws IOException {
        return new String(Files.readAllBytes(file.toPath()), StandardCharsets.UTF_8);
    }

    /**
     * Counts non-overlapping occurrences of {@code kmer} in {@code text}, scanning left to right
     * (e.g. "AA" occurs twice in "AAAA" and once in "AAA").
     *
     * @param text the text to search
     * @param kmer the pattern; an empty pattern yields 0
     * @return the number of non-overlapping matches
     */
    private static int countNonOverlapping(String text, String kmer) {
        if (kmer.isEmpty()) {
            return 0;
        }
        int count = 0;
        for (int index = text.indexOf(kmer); index != -1;
                index = text.indexOf(kmer, index + kmer.length())) {
            count++;
        }
        return count;
    }

    /**
     * Counts overlapping occurrences of each k-mer in a dataset file.
     *
     * <p>Currently not called by {@link #main}. Note that it counts <em>overlapping</em>
     * matches, unlike {@link #countNonOverlapping}, which {@code main} uses.
     *
     * @param datasetFile the dataset to read
     * @param kmers       k-mers to count
     * @return k-mer → count; empty if the file could not be read
     */
    private static Map<String, Integer> countKmersInDataset(File datasetFile, Set<String> kmers) {
        Map<String, Integer> kmerCounts = new HashMap<>();
        try {
            String content = readContent(datasetFile);
            for (String kmer : kmers) {
                int count = countOccurrences(content, kmer);
                kmerCounts.put(kmer, count);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return kmerCounts;
    }

    /**
     * Counts overlapping occurrences of a k-mer in a string (advances one position per match).
     * Currently not called by {@link #main}.
     */
    private static int countOccurrences(String text, String kmer) {
        int count = 0;
        int index = 0;
        while ((index = text.indexOf(kmer, index)) != -1) {
            count++;
            index += 1; // Move one position at a time to catch overlapping occurrences
        }
        return count;
    }

    /**
     * Counts k-mers whose highest-count dataset equals {@link #getActualClass}'s result.
     * Currently not called by {@link #main}.
     */
    private static int classifyKmers(Set<String> kmers, Map<String, Map<String, Integer>> datasetKmerCounts) {
        int correctClassifications = 0;
        for (String kmer : kmers) {
            String predictedClass = getPredictedClass(kmer, datasetKmerCounts);
            String actualClass = getActualClass(kmer);
            if (predictedClass.equals(actualClass)) {
                correctClassifications++;
            }
        }
        return correctClassifications;
    }

    /**
     * Returns the dataset name in which {@code kmer} has the highest raw count, printing each
     * count. Currently not called by {@link #main}.
     */
    private static String getPredictedClass(String kmer, Map<String, Map<String, Integer>> datasetKmerCounts) {
        String predictedClass = "";
        int maxCount = -1;
        for (Map.Entry<String, Map<String, Integer>> entry : datasetKmerCounts.entrySet()) {
            String datasetName = entry.getKey();
            int count = entry.getValue().getOrDefault(kmer, 0);
            // Debug output
            System.out.println(" " + kmer + " occurs " + count + " times in " + datasetName);
            if (count > maxCount) {
                maxCount = count;
                predictedClass = datasetName;
            }
        }
        return predictedClass;
    }

    /**
     * Returns the part of {@code kmer} before its first '.'. Currently not called by
     * {@link #main}.
     *
     * <p>NOTE(review): this splits the k-mer string itself, not a file name, so it does not
     * derive a class from a dataset or k-mer file name. Confirm intent before using it.
     */
    private static String getActualClass(String kmer) {
        return kmer.split("\\.")[0];
    }
}