
Commit

Cross-validation script: Add filter for extensions (#6490)
Add filter for extensions
DecimalTurn authored Sep 7, 2023
1 parent f36d835 commit 073af2e
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions script/cross-validation
@@ -13,12 +13,15 @@ STDOUT.sync = true
STDERR.sync = true

def print_usage(out)
out.puts "Usage: #{$PROGRAM_NAME} [--all] [--test]"
out.puts "Usage: #{$PROGRAM_NAME} [--all] [--extensions=<list>] [--test]"
out.puts ''
out.puts 'Performs leave-one-out cross-validation of the classifier.'
out.puts 'By default, outputs results only for samples with ambiguous extensions.'
out.puts 'If the --all flag is given, all samples and languages are considered.'
out.puts ''
out.puts 'If the --extensions option is used, only the extensions specified in the comma-separated list will be considered.'
out.puts 'Extensions in the list must include the starting dot.'
out.puts ''
out.puts 'The --test flag can be used to verify that the number of errors is acceptable.'
end

@@ -29,12 +32,15 @@ end

$all = false
$test = false
ARGV.each do |arg|
$exts = []
ARGV.each_with_index do |arg, index|
case arg
when '--all'
$all = true
when '--test'
$test = true
when /^--extensions=(.*)$/
$exts = $1.delete("'\"").split(',').map(&:strip)
else
STDERR.puts "Invalid command line argument: #{arg}"
STDERR.puts ''
@@ -70,6 +76,11 @@ end
def eval(sample)
return nil if $skip_extensions.include? sample[:extname]

# Apply extensions list filter
if $exts.any? && !$exts.include?(File.extname(sample[:path]))
return nil
end

# If --all is set, use all languages. Otherwise, get only languages that are
# ambiguous in terms of filename and extension.
if $all
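
The option handling added above can be tried in isolation. The following is a minimal sketch rather than code from the script: `parse_extensions` is a hypothetical helper name, and it only reuses the regular expression and string operations shown in the diff.

```ruby
# Hypothetical helper (not part of script/cross-validation): mirrors how the
# diff turns a --extensions=<list> argument into an array of extensions.
def parse_extensions(arg)
  match = arg.match(/^--extensions=(.*)$/)
  return [] unless match

  # Drop single and double quotes, split on commas, trim surrounding spaces.
  match[1].delete("'\"").split(',').map(&:strip)
end

p parse_extensions('--extensions=.m,.h')        # => [".m", ".h"]
p parse_extensions("--extensions='.pl, .pm'")   # => [".pl", ".pm"]
p parse_extensions('--all')                     # => []
```

Quotes are stripped because a shell may or may not pass them through, so both `--extensions=.m,.h` and a quoted list yield the same array.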

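The filter added to `eval` skips a sample when a list was given and the sample's extension is not in it. The sketch below restates that condition as a standalone predicate, assuming samples are identified by their path. `keep_sample?` and the sample paths are hypothetical and used only for illustration.

```ruby
# Hypothetical predicate (not in the script): true when a sample path passes
# the extensions filter. An empty list means no filtering, as in the diff.
def keep_sample?(path, exts)
  exts.empty? || exts.include?(File.extname(path))
end

exts = ['.m', '.h']
p keep_sample?('samples/Objective-C/Foo.m', exts)  # => true
p keep_sample?('samples/Ruby/foo.rb', exts)        # => false
p keep_sample?('samples/Ruby/foo.rb', [])          # => true
p keep_sample?('samples/MATLAB/Foo.M', exts)       # => false
```

The comparison is an exact, case-sensitive string match against the result of `File.extname`, which is why list entries must include the leading dot and why `.M` does not match `.m`.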