Skip to content

Commit

Permalink
DEV: Add wrapper functions for covariate support (#237)
Browse files Browse the repository at this point in the history
FEAT: Add wrapper class for importing covariates

FEAT: Add wrapper class for unioning features and covariates
  • Loading branch information
NickEdwards7502 committed Sep 19, 2024
1 parent d671f35 commit 209a463
Showing 1 changed file with 48 additions and 11 deletions.
59 changes: 48 additions & 11 deletions python/varspark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from varspark import java
from varspark.etc import find_jar
from varspark.featuresource import FeatureSource
from varspark.covariatesource import CovariateSource


class VarsparkContext(object):
Expand Down Expand Up @@ -68,21 +67,59 @@ def import_vcf(self, vcf_file_path, min_partitions=0):
self._jvsc.importVCF(vcf_file_path, min_partitions),
)

@params(self=object, cov_file_path=str, cov_types=(list, dict))
def import_covariates(self, cov_file_path, cov_types):
"""Import covariates from a CSV file."""
if isinstance(cov_types, list):
types_rdd = self._jvm.SparkContext.parallelize(cov_types)
elif isinstance(cov_types, dict):
types_rdd = self._jvm.SparkContext.parallelize(cov_types.items())
@params(
self=object,
cov_file_path=str,
cov_types=Nullable(dict),
transposed=Nullable(bool),
)
def import_covariates(self, cov_file_path, cov_types=None, transposed=False):
"""Import covariates from a CSV file.
:param cov_file_path: The file path for covariate csv file
:param cov_types Dict[String]:
A dictionary specifying types for each covariate, where the key is the variable name
and the value is the type. The value can be one of the following:
- CONTINUOUS: A continuous variable type.
- DISCRETE: A discrete variable type.
- NOMINAL: A nominal variable type.
- ORDINAL: An ordinal variable type.
- ORDINAL(order_count): Specifies the number of ordered levels, where `order_count` represents the number of levels.
- NOMINAL(class_count): Specifies the number of distinct classes, where `class_count` represents the number of categories.
See VariableType.scala for more information.
:param transposed bool: Whether or not the covariate csv file is transposed
"""
if cov_types is not None:
cov_types_list = [f"{k},{c}" for k, c in cov_types.items()]
_jctypes = self._jvm.java.util.ArrayList()
for item in cov_types_list:
_jctypes.add(item)
else:
_jctypes = None
if transposed:
_jcs = self._jvsc.importTransposedCSV(cov_file_path, cov_types_list)
else:
types_rdd = None
return CovariateSource(
_jcs = self._jvsc.importStdCSV(cov_file_path)
return FeatureSource(
self._jvm,
self._vs_api,
self._jsql,
self.sql,
_jcs,
)

@params(self=object, feature_source=FeatureSource, covariate_source=FeatureSource)
def union_features_and_covariates(self, feature_source, covariate_source):
return FeatureSource(
self._jvm,
self._vs_api,
self._jsql,
self.sql,
self._jvsc.importCSV(inputFile=cov_file_path, optVariableTypes=types_rdd),
self._jvsc.unionFeaturesAndCovariates(
feature_source._jfs, covariate_source._jfs
),
)

@params(self=object, label_file_path=str, col_name=str)
Expand Down

0 comments on commit 209a463

Please sign in to comment.