Skip to content

Commit

Permalink
DEV: Add covariate import wrapper function (#237)
Browse files Browse the repository at this point in the history
STYLE: Format with black
  • Loading branch information
NickEdwards7502 committed Sep 13, 2024
1 parent de29b45 commit fe2db4c
Showing 1 changed file with 48 additions and 22 deletions.
70 changes: 48 additions & 22 deletions python/varspark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from varspark import java
from varspark.etc import find_jar
from varspark.featuresource import FeatureSource
from varspark.covariatesource import CovariateSource


class VarsparkContext(object):
"""The main entry point for VariantSpark functionality.
"""
"""The main entry point for VariantSpark functionality."""

@classmethod
def spark_conf(cls, conf=SparkConf()):
""" Adds the necessary option to the spark configuration.
"""Adds the necessary option to the spark configuration.
Note: In client mode these need to be setup up using --jars or --driver-class-path
"""
return conf.set("spark.jars", find_jar())
Expand All @@ -30,46 +31,71 @@ def __init__(self, ss, silent=False):
self.sql = SQLContext.getOrCreate(self.sc)
self._jsql = self.sql._jsqlContext
self._jvm = self.sc._jvm
self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
self._vs_api = getattr(self._jvm, "au.csiro.variantspark.api")
jss = ss._jsparkSession
self._jvsc = self._vs_api.VSContext.apply(jss)

setup_typecheck()

if not self.silent:
sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version))
sys.stderr.write(
"Running on Apache Spark version {}\n".format(self.sc.version)
)
if self.sc._jsc.sc().uiWebUrl().isDefined():
sys.stderr.write('SparkUI available at {}\n'.format(
self.sc._jsc.sc().uiWebUrl().get()))
sys.stderr.write(
"SparkUI available at {}\n".format(
self.sc._jsc.sc().uiWebUrl().get()
)
)
sys.stderr.write(
'Welcome to\n'
' _ __ _ __ _____ __ \n'
'| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n'
'| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n'
'| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n'
'|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n'
' /_/ \n')
"Welcome to\n"
" _ __ _ __ _____ __ \n"
"| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n"
"| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n"
"| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n"
"|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n"
" /_/ \n"
)

@params(self=object, vcf_file_path=str, min_partitions=int)
def import_vcf(self, vcf_file_path, min_partitions=0):
""" Import features from a VCF file.
"""
return FeatureSource(self._jvm, self._vs_api,
self._jsql, self.sql, self._jvsc.importVCF(vcf_file_path,
min_partitions))
"""Import features from a VCF file."""
return FeatureSource(
self._jvm,
self._vs_api,
self._jsql,
self.sql,
self._jvsc.importVCF(vcf_file_path, min_partitions),
)

    @params(self=object, cov_file_path=str, cov_types=(list, dict))
    def import_covariates(self, cov_file_path, cov_types):
        """Import covariates from a CSV file.

        :param cov_file_path: path to the CSV file holding the covariates.
        :param cov_types: covariate type declarations — either a list
            (presumably one entry per column) or a dict (presumably
            covariate name -> type); TODO confirm against the Scala
            ``importCSV`` API.
        :return: a :class:`CovariateSource` wrapping the imported data.
        """
        if isinstance(cov_types, list):
            # NOTE(review): through the py4j gateway, `self._jvm.SparkContext`
            # resolves to the Scala companion object, not an active context
            # instance — calling `parallelize` on it looks suspect; confirm
            # this actually works at runtime.
            types_rdd = self._jvm.SparkContext.parallelize(cov_types)
        elif isinstance(cov_types, dict):
            # NOTE(review): `dict.items()` is a Python-side view object;
            # verify that py4j converts it into a JVM collection that
            # `parallelize` accepts.
            types_rdd = self._jvm.SparkContext.parallelize(cov_types.items())
        else:
            # Unreachable as written: the @params typecheck above restricts
            # cov_types to list or dict. Kept as a defensive fallback
            # meaning "no explicit variable types supplied".
            types_rdd = None
        return CovariateSource(
            self._jvm,
            self._vs_api,
            self._jsql,
            self.sql,
            self._jvsc.importCSV(inputFile=cov_file_path, optVariableTypes=types_rdd),
        )

@params(self=object, label_file_path=str, col_name=str)
def load_label(self, label_file_path, col_name):
""" Loads the label source file
"""Loads the label source file
:param label_file_path: The file path for the label source file
:param col_name: the name of the column containing labels
"""
return self._jvsc.loadLabel(label_file_path, col_name)

def stop(self):
""" Shut down the VariantsContext.
"""
"""Shut down the VariantsContext."""

self.sc.stop()
self.sc = None
Expand Down

0 comments on commit fe2db4c

Please sign in to comment.