New Centroid-based Classifier #5103

Merged · 18 commits · Aug 29, 2024
421 changes: 315 additions & 106 deletions lib/linguist/classifier.rb

Large diffs are not rendered by default.
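The collapsed classifier.rb diff holds the core of the new approach. For orientation only, here is a minimal sketch of how a centroid-based classifier can work: one averaged token-frequency vector per language, scored by cosine similarity. The train!, finalize_train! and classify names mirror the calls visible in the tests further down; the weighting, the similarity measure and the data layout are assumptions, not the PR's actual implementation.

# Sketch only: per-language centroids over normalized token counts.
module CentroidSketch
  # Record one sample's token counts for a language ("tokens" is an Array of token strings).
  def self.train!(db, language, tokens)
    db['samples'] ||= Hash.new { |h, k| h[k] = [] }
    db['samples'][language] << tokens.tally
  end

  # Collapse the accumulated samples into one centroid vector per language.
  def self.finalize_train!(db)
    db['centroids'] = {}
    db['samples'].each do |language, samples|
      centroid = Hash.new(0.0)
      samples.each do |counts|
        total = counts.values.sum.to_f
        counts.each { |tok, c| centroid[tok] += (c / total) / samples.size }
      end
      db['centroids'][language] = centroid
    end
    db.delete('samples')
  end

  # Rank candidate languages by cosine similarity between the input and each centroid.
  def self.classify(db, tokens, languages = nil)
    return [] if tokens.empty?
    vec   = tokens.tally
    vnorm = Math.sqrt(vec.values.sum { |c| c * c })
    (languages || db['centroids'].keys).filter_map do |language|
      centroid = db['centroids'][language] or next
      cnorm = Math.sqrt(centroid.values.sum { |w| w * w })
      next if cnorm.zero? || vnorm.zero?
      dot = vec.sum { |tok, c| c * centroid.fetch(tok, 0.0) }
      [language, dot / (vnorm * cnorm)]
    end.sort_by { |_, score| -score }
  end
end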

10 changes: 9 additions & 1 deletion lib/linguist/samples.rb
@@ -25,7 +25,14 @@ def self.cache
# Hash of serialized samples object, uncached
def self.load_samples
serializer = defined?(Yajl) ? Yajl : JSON
serializer.load(File.read(PATH, encoding: 'utf-8'))
data = serializer.load(File.read(PATH, encoding: 'utf-8'))
# JSON serialization does not allow integer keys, we fix them here
for lang in data['centroids'].keys
fixed = data['centroids'][lang].to_a.map { |k,v| [k.to_i, v] }
data['centroids'][lang] = Hash[fixed]
end

data
end

# Public: Iterate over each sample.
@@ -106,6 +113,7 @@ def self.data
Classifier.train!(db, language_name, data)
end

Classifier.finalize_train! db
db['sha256'] = Linguist::SHA256.hexdigest(db)

db
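For context on the load_samples change above: a JSON round trip in Ruby turns integer hash keys into strings, so the centroid vectors, which appear to be keyed by integer vocabulary indices, need their keys converted back after deserialization. A tiny self-contained illustration (values invented):

require 'json'

centroid = { 0 => 0.42, 7 => 0.13 }                    # integer keys before serialization
reloaded = JSON.parse(JSON.generate(centroid))         # => {"0"=>0.42, "7"=>0.13}
restored = reloaded.map { |k, v| [k.to_i, v] }.to_h    # => {0=>0.42, 7=>0.13}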
2 changes: 1 addition & 1 deletion lib/linguist/sha256.rb
@@ -13,7 +13,7 @@ def self.hexdigest(obj)
digest = Digest::SHA256.new

case obj
when String, Symbol, Integer
when String, Symbol, Integer, Float
digest.update "#{obj.class}"
digest.update "#{obj}"
when TrueClass, FalseClass, NilClass
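The one-line sha256.rb change lets the cache checksum cover Float values, presumably because the centroid weights are floats that now flow through Linguist::SHA256.hexdigest. Per the diff, a Float is folded into the digest like the other scalars: class name, then string form.

require 'digest/sha2'

digest = Digest::SHA256.new
weight = 0.125                     # hypothetical centroid weight
digest.update "#{weight.class}"    # adds "Float"
digest.update "#{weight}"          # adds "0.125"
digest.hexdigest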
183 changes: 183 additions & 0 deletions samples/R/2.R
@@ -0,0 +1,183 @@
library(cowplot)
library(pheatmap)
library(dplyr)

set.seed(12323414)
options(stringsAsFactors = FALSE)
font_size <- 8

gtable_select <- function (x, ...)
{
matches <- c(...)
x$layout <- x$layout[matches, , drop = FALSE]
x$grobs <- x$grobs[matches]
x
}

gtable_stack <- function(g1, g2){
g1$grobs <- c(g1$grobs, g2$grobs)
g1$layout <- transform(g1$layout, z= z-max(z), name="g2")
g1$layout <- rbind(g1$layout, g2$layout)
g1
}

get_a <- function() {
d <- read.csv("data/2a.csv")

d[d$Hierarchy == "Kolodziejczyk", ]$Hierarchy <- "Kolodz."

d$Hierarchy <- factor(
d$Hierarchy,
levels = c("Biase", "Yan", "Goolam", "Deng", "Pollen1", "Pollen2",
"Kolodz.", "Treutlein", "Ting",
"Patel", "Usoskin1", "Usoskin2", "Usoskin3",
"Klein", "Zeisel")
)

d$Method <- factor(
d$Method,
levels = c("SC3", "tSNE+kmeans", "pcaReduce", "SNN-Cliq", "SINCERA", "SEURAT")
)

cols <- c("Biase" = "#bc80bd", "Treutlein" = "#8dd3c7", "Ting" = "#ffffb3",
"Yan" = "#ccebc5", "Goolam" = "#ffed6f", "Deng" = "#bebada",
"Pollen1" = "#fb8072", "Pollen2" = "#fb8072",
"Patel" = "#80b1d3", "Usoskin1" = "#fdb462", "Usoskin2" = "#fdb462",
"Usoskin3" = "#fdb462", "Kolodz." = "#bf812d",
"Klein" = "#b3de69", "Zeisel" = "#fccde5", "Macosko" = "#d9d9d9")

meth_cols <- c(
"SC3" = "#e41a1c",
"tSNE+kmeans" = "#377eb8",
"pcaReduce" = "#40E0D0",
"SNN-Cliq" = "#984ea3",
"SINCERA" = "#ff7f00",
"SEURAT" = "#ffff33"
)

d1 <- d %>%
group_by(Method, Hierarchy) %>%
dplyr::summarise(Median = median(ARI))

p <- ggplot(d, aes(x = 1, y = ARI, fill = Method, group = Method)) +
geom_bar(data = d1, aes(y = Median), position="dodge", stat="identity") +
geom_point(position = position_jitterdodge(jitter.width = 0.45, dodge.width = 0.9), size = 0.4) +
facet_wrap(ncol = 5, ~ Hierarchy) +
scale_fill_manual(values = meth_cols) +
scale_colour_manual(values = meth_cols) +
geom_hline(yintercept = 0.8) +
labs(x = "") +
theme_classic(base_size = font_size) +
theme(axis.ticks.x = element_blank(), axis.text.x=element_blank(),
axis.title.x=element_blank(), axis.line=element_blank(),
legend.key.size = unit(0.4, "cm")) +
annotate("segment", x=-Inf, xend=Inf, y=-Inf, yend=-Inf, color = "black")+
annotate("segment", x=-Inf, xend=-Inf, y=-Inf, yend=Inf, color = "black")


dummy <- ggplot(d, aes(x = 1, y = ARI, fill = Method)) +
facet_wrap(ncol = 5, ~ Hierarchy) +
geom_rect(aes(fill = Hierarchy), xmin=-Inf, xmax=Inf, ymin=-Inf, ymax=Inf) +
scale_fill_manual(values = cols) +
theme_minimal()

g1 <- ggplotGrob(p)
g2 <- ggplotGrob(dummy)

panels <- grepl(pattern="panel", g2$layout$name)
strips <- grepl(pattern="strip-t", g2$layout$name)
g2$layout$t[panels] <- g2$layout$t[panels] - 1
g2$layout$b[panels] <- g2$layout$b[panels] - 1

new_strips <- gtable_select(g2, panels | strips)

new_plot <- gtable_stack(g1, new_strips)
return(new_plot)
}

get_c <- function() {
cols <- c("Treutlein" = "#8dd3c7", "Ting" = "#ffffb3", "Deng" = "#bebada",
"Pollen2" = "#fb8072", "Patel" = "#80b1d3",
"Kolodziejczyk" = "#bf812d", "Usoskin3" = "#fdb462",
"Klein" = "#40E0D0", "Zeisel" = "#fccde5", "Macosko" = "#d9d9d9")

d <- read.csv("data/2c.csv")

d$Dataset <- factor(
d$Dataset,
levels = c(
"Deng",
"Pollen2",
"Kolodziejczyk",
"Patel",
"Usoskin3",
"Klein",
"Zeisel",
"Macosko"
)
)

d$Fraction <- factor(
d$Fraction,
levels = sort(unique(as.numeric(d$Fraction)))
)

p <- ggplot(d, aes(x = 1, ARI, fill = Dataset, color = Dataset)) +
geom_boxplot(position = position_dodge(width = 1.5), outlier.size = 0.8) +
geom_hline(yintercept = 0.8) +
labs(x = "# of training cells as % of N", y = "ARI") +
scale_fill_manual(values = cols) +
scale_colour_manual(values = cols) +
facet_grid(. ~ Fraction) +
theme_classic(base_size = font_size) +
theme(axis.ticks.x = element_blank(), axis.text.x=element_blank(),
axis.title.x=element_blank(), axis.line=element_blank(),
strip.background = element_rect(colour = "white"),
legend.key.size = unit(0.4, "cm")) +
ylim(0,1) +
annotate("segment", x=-Inf, xend=Inf, y=-Inf, yend=-Inf, color = "black")+
annotate("segment", x=-Inf, xend=-Inf, y=-Inf, yend=Inf, color = "black")
p <- ggdraw(p) +
draw_label("% of total # of cells\nin a training set",
fontface = "bold",
size = font_size-3,
x = 0.87, y = 0.93)
return(p)
}

get_d <- function() {
d <- readRDS("data/2d.rds")
ann <- data.frame(Stage = factor(d$cell.names, levels = unique(d$cell.names)))
anno_colors <- list(Stage = c("#A6CEE3", "#1F78B4", "#B2DF8A", "#33A02C",
"#FB9A99", "#FF00FF", "#FDBF6F", "#FF7F00",
"#CAB2D6", "#6A3D9A"))
names(anno_colors$Stage) <- levels(ann$Stage)
dat <- d$consensus
colnames(dat) <- d$cell.names
write.csv(dat[d$hc$order, d$hc$order], file = "data/2d.csv", quote = FALSE, row.names = FALSE)
p <- pheatmap(d$consensus,
cluster_rows = d$hc,
cluster_cols = d$hc,
cutree_rows = 10,
cutree_cols = 10,
treeheight_col = 9,
treeheight_row = 9,
annotation_col = ann,
annotation_colors = anno_colors,
show_rownames = F,
show_colnames = F,
fontsize = font_size,
annotation_names_col = F,
silent = TRUE)
return(p$gtable)
}

first_col <- plot_grid(get_a(), get_c(), nrow = 2, labels = c("a", "c"), rel_heights = c(2, 1))

second_col <- plot_grid(NULL, get_d(), nrow = 2, labels = c("b", "d"), rel_heights = c(1.5, 1))

plot_grid(first_col, second_col, ncol = 2)

ggsave("jpeg/2.jpeg", w = 9, h = 6)
ggsave("pdf/2.pdf", w = 9, h = 6)

4 changes: 0 additions & 4 deletions samples/R/hello-r.R

This file was deleted.

3 changes: 2 additions & 1 deletion script/cross-validation
@@ -2,7 +2,7 @@

# Number of acceptable classification errors.
# It should only be decreased.
ACCEPTABLE_ERRORS = 39
ACCEPTABLE_ERRORS = 19

# Number of acceptable classification errors when using --all.
# It should only be decreased.
@@ -111,6 +111,7 @@ def eval(sample)
train_samples.each do |train_sample|
Classifier.train!(db, train_sample[:language], train_sample[:tokens])
end
Classifier.finalize_train! db

# Get result.
results = Classifier.classify(db, sample[:data], languages)
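The cross-validation hunk shows the evaluation flow: train on every other sample, seal the db with finalize_train!, then classify the held-out sample; the lowered ACCEPTABLE_ERRORS (39 to 19) is the budget of held-out misclassifications the script tolerates. A condensed sketch of that loop, where train_samples, sample[:data] and languages come from the visible lines and the surrounding plumbing (sample selection, error counting) is assumed:

# samples: array of {path:, language:, tokens:, data:}; languages: candidate names
# (both prepared elsewhere in the script; shapes assumed here).
errors = 0
samples.each do |sample|
  db = {}
  train_samples = samples.reject { |s| s[:path] == sample[:path] }   # assumed hold-out rule
  train_samples.each do |train_sample|
    Classifier.train!(db, train_sample[:language], train_sample[:tokens])
  end
  Classifier.finalize_train! db
  results = Classifier.classify(db, sample[:data], languages)
  errors += 1 unless results.first && results.first[0] == sample[:language]
end
exit 1 if errors > ACCEPTABLE_ERRORS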
4 changes: 0 additions & 4 deletions test/fixtures/AngelScript/ClassDef.as

This file was deleted.

10 changes: 7 additions & 3 deletions test/test_blob.rb
@@ -288,7 +288,7 @@ def test_language
root = File.expand_path('../fixtures', __FILE__)
Dir.entries(root).each do |language|
next if language == '.' || language == '..' || language == 'Binary' ||
File.basename(language) == 'ace_modes.json'
File.basename(language) == 'ace_modes.json'

# Each directory contains test files of a language
dirname = File.join(root, language)
@@ -306,9 +306,13 @@ def test_language
elsif language == 'Generic'
assert !blob.language, "#{filepath} should not match a language"
else
assert blob.language, "No language for #{filepath}"
fs_name = blob.language.fs_name ? blob.language.fs_name : blob.language.name
assert_equal language, fs_name, blob.name
if allowed_failures.has_key? filepath
assert allowed_failures[filepath].include?(fs_name), filepath
else
assert blob.language, "No language for #{filepath}"
assert_equal language, fs_name, filepath
end
end
end
end
14 changes: 11 additions & 3 deletions test/test_classifier.rb
@@ -12,6 +12,7 @@ def test_classify
Classifier.train! db, "Ruby", fixture("Ruby/foo.rb")
Classifier.train! db, "Objective-C", fixture("Objective-C/Foo.h")
Classifier.train! db, "Objective-C", fixture("Objective-C/Foo.m")
Classifier.finalize_train! db

results = Classifier.classify(db, fixture("Objective-C/hello.m"))
assert_equal "Objective-C", results.first[0]
@@ -26,17 +27,18 @@ def test_restricted_classify
Classifier.train! db, "Ruby", fixture("Ruby/foo.rb")
Classifier.train! db, "Objective-C", fixture("Objective-C/Foo.h")
Classifier.train! db, "Objective-C", fixture("Objective-C/Foo.m")
Classifier.finalize_train! db

results = Classifier.classify(db, fixture("Objective-C/hello.m"), ["Objective-C"])
assert_equal "Objective-C", results.first[0]

results = Classifier.classify(db, fixture("Objective-C/hello.m"), ["Ruby"])
assert_equal "Ruby", results.first[0]
assert results.empty?
end

def test_instance_classify_empty
results = Classifier.classify(Samples.cache, "")
assert results.first[1] < 0.5, results.first.inspect
assert results.empty?
end

def test_instance_classify_nil
@@ -46,7 +48,11 @@ def test_classify_ambiguous_languages
def test_classify_ambiguous_languages
# Failures are reasonable in some cases, such as when a file is fully valid in more than one language.
allowed_failures = {
# Valid C and C++
"#{samples_path}/C/rpc.h" => ["C", "C++"],
# Tricky samples
"#{samples_path}/C/syscalldefs.h" => ["C", "C++"],
"#{samples_path}/C++/Types.h" => ["C", "C++"],
}

# Skip extensions with catch-all rule
@@ -70,7 +76,9 @@

results = Classifier.classify(Samples.cache, File.read(sample[:path]), languages)

if allowed_failures.has_key? sample[:path]
if results.empty?
assert false, "no results for #{sample[:path]}"
elsif allowed_failures.has_key? sample[:path]
assert allowed_failures[sample[:path]].include?(results.first[0]), "#{sample[:path]}\n#{results.inspect}"
else
assert_equal language.name, results.first[0], "#{sample[:path]}\n#{results.inspect}"
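Taken together, the updated assertions pin down the classifier's contract: training is two-phase (train! per sample, then one finalize_train!), classify returns [language, score] pairs with the best match first, and it returns an empty array rather than a low-confidence guess when there is nothing usable (empty input, or no viable candidate language). A usage sketch assembled only from the calls the tests make (fixture is the test helper used above):

db = {}
Classifier.train! db, "Ruby", fixture("Ruby/foo.rb")
Classifier.train! db, "Objective-C", fixture("Objective-C/Foo.m")
Classifier.finalize_train! db

Classifier.classify(db, fixture("Objective-C/hello.m")).first[0]   # => "Objective-C"
Classifier.classify(db, "")                                        # => [] (empty input yields no results per the tests)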
8 changes: 6 additions & 2 deletions test/test_file_blob.rb
@@ -706,9 +706,13 @@ def test_language
elsif language == 'Generic'
assert !blob.language, "#{filepath} should not match a language"
else
assert blob.language, "No language for #{filepath}"
fs_name = blob.language.fs_name ? blob.language.fs_name : blob.language.name
assert_equal language, fs_name, blob.name
if allowed_failures.has_key? filepath
assert allowed_failures[filepath].include?(fs_name), filepath
else
assert blob.language, "No language for #{filepath}"
assert_equal language, fs_name, filepath
end
end
end
end
8 changes: 4 additions & 4 deletions test/test_repository.rb
@@ -74,10 +74,10 @@ def test_repo_git_attributes
assert !repo.breakdown_by_file["Ruby"].empty?

# Ensures the filename that contains unicode char is UTF-8 encoded and invalid chars scrubbed
assert repo.breakdown_by_file.has_key?("Perl")
assert repo.breakdown_by_file["Perl"].include?("test/fixtures/ba�r/file_ã.pl")
assert_equal "UTF-8", repo.breakdown_by_file["Perl"].first.encoding.to_s
assert repo.breakdown_by_file["Perl"].first.valid_encoding?
assert repo.breakdown_by_file.has_key?("Raku")
assert repo.breakdown_by_file["Raku"].include?("test/fixtures/ba�r/file_ã.pl")
assert_equal "UTF-8", repo.breakdown_by_file["Raku"].first.encoding.to_s
assert repo.breakdown_by_file["Raku"].first.valid_encoding?
end

def test_commit_with_git_attributes_data
9 changes: 6 additions & 3 deletions test/test_samples.rb
@@ -28,10 +28,13 @@ def test_up_to_date
def test_verify
assert data = Samples.cache

assert_equal data['languages_total'], data['languages'].inject(0) { |n, (_, c)| n += c }
assert_equal data['tokens_total'], data['language_tokens'].inject(0) { |n, (_, c)| n += c }
assert_equal data['tokens_total'], data['tokens'].inject(0) { |n, (_, ts)| n += ts.inject(0) { |m, (_, c)| m += c } }
assert !data["vocabulary"].empty?
assert !data["icf"].empty?
assert !data["centroids"].empty?
assert_equal data["icf"].size, data["vocabulary"].size
assert !data["extnames"].empty?
assert !data["interpreters"].empty?
assert !data["filenames"].empty?
end

def test_ext_or_shebang
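test_verify above implies the shape of the regenerated samples cache: alongside the existing extnames, interpreters and filenames data it must carry a non-empty vocabulary, an icf table with one entry per vocabulary term, and per-language centroids. A purely hypothetical fragment to visualize those assertions (keys from the test, values and nesting invented):

data = {
  "vocabulary"   => { "def" => 0, "end" => 1, "#include" => 2 },   # token => index (assumed layout)
  "icf"          => { 0 => 1.7, 1 => 1.7, 2 => 2.3 },              # same size as the vocabulary
  "centroids"    => { "Ruby" => { 0 => 0.12, 1 => 0.11 },
                      "C"    => { 2 => 0.25 } },                   # integer keys, hence the samples.rb fix
  "extnames"     => { "Ruby" => [".rb"], "C" => [".c", ".h"] },
  "interpreters" => { "Ruby" => ["ruby"] },
  "filenames"    => { "Ruby" => ["Rakefile"] },
}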