Skip to content

Commit

Permalink
DEV: Update python unit testing (#237)
Browse files Browse the repository at this point in the history
REFACTOR: Refactor to mirror changes to python wrapper

FEAT: Include FDR calculation in unit test
  • Loading branch information
NickEdwards7502 committed Sep 11, 2024
1 parent f6d40d4 commit 3356d9a
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions python/varspark/test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyspark import SparkConf
from pyspark.sql import SparkSession

from varspark import VariantsContext
from varspark import VariantsContext, RFModelContext
from varspark.test import find_variants_jar, PROJECT_DIR

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -28,7 +28,8 @@ def tearDownClass(self):


class VariantSparkAPITestCase(VariantSparkPySparkTestCase):

# Attributes assigned on self are only available to the other tests if they are initialised in setUp.
# TODO: consider defining the model, importance analysis and FDR calculation in setUp as well, so that several unit tests can share them.
def setUp(self):
self.spark = SparkSession(self.sc)
self.vc = VariantsContext(self.spark)
Expand All @@ -39,22 +40,25 @@ def test_variants_context_parameter_type(self):
self.assertEqual('keyword argument label_file_path = 123 doesn\'t match signature str',
str(cm.exception))

def test_rfmodel(self):
    """Fit a random-forest model on the chr22 sample data and check its outputs.

    Labels come from column ``22_16050678`` of ``data/chr22-labels.csv`` and
    features from ``data/chr22_1000.vcf``.  The model is built with a fixed
    seed (17), 200 trees and a batch size of 50, after which the test checks:

    * the top-ranked variable reported by ``important_variables``,
    * the top-ranked variant in the normalized ``variable_importance`` table,
    * the out-of-bag error,
    * the FDR returned by the local-FDR calculator.
    """
    label_data_path = os.path.join(PROJECT_DIR, 'data/chr22-labels.csv')
    label = self.vc.load_label(label_file_path=label_data_path, col_name='22_16050678')
    feature_data_path = os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf')
    features = self.vc.import_vcf(vcf_file_path=feature_data_path)

    rf = RFModelContext(self.spark, mtry_fraction=None, oob=True, seed=17,
                        var_ordinal_levels=3)
    rf.fit_trees(features, label, n_trees=200, batch_size=50)

    imp_analysis = rf.importance_analysis()
    imp_vars = imp_analysis.important_variables(20)
    most_imp_var = imp_vars['variable'][0]
    self.assertEqual('22_16050678_C_T', most_imp_var)

    df = imp_analysis.variable_importance(normalized=True)
    top_variant = df.sort_values(by='importance', ascending=False)['variant_id'].iloc[0]
    self.assertEqual('22_16050678_C_T', str(top_variant))

    # Floating-point results are compared with a tolerance rather than exact
    # equality, so harmless last-digit differences between platforms or
    # library versions do not fail the test.
    oob_error = rf.oob_error()
    self.assertAlmostEqual(0.004578754578754579, oob_error, places=4)

    fdr_calc = rf.get_lfdr()
    _, fdr = fdr_calc.compute_fdr(countThreshold=2, local_fdr_cutoff=0.05)
    # The expected FDR is ~3e-4, so a tighter tolerance than for the OOB error
    # is needed for the comparison to remain meaningful.
    self.assertAlmostEqual(0.0002976892628282768, fdr, places=6)

# Run the test suite with verbose output when this file is executed directly.
if __name__ == '__main__':
unittest.main(verbosity=2)

0 comments on commit 3356d9a

Please sign in to comment.