Skip to content

Commit

Permalink
DEV: Updated varspark python wrapper (#237)
Browse files Browse the repository at this point in the history
REFACTOR: Removed FeatureSource and
ImportanceAnalysis classes from core

REFACTOR: Added FeatureSource import so features
can be returned as a class instantiation
  • Loading branch information
NickEdwards7502 committed Sep 11, 2024
1 parent 5edfbfa commit 80a9c59
Showing 1 changed file with 1 addition and 84 deletions.
85 changes: 1 addition & 84 deletions python/varspark/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import sys
from random import randint

from pyspark import SparkConf
from pyspark.sql import SQLContext
from typedecorator import params, Nullable, setup_typecheck

from varspark import java
from varspark.etc import find_jar

from varspark.featuresource import FeatureSource

class VarsparkContext(object):
"""The main entry point for VariantSpark functionality.
Expand Down Expand Up @@ -78,85 +77,3 @@ def stop(self):

# Deprecated: legacy alias bound to the same class object as VarsparkContext.
VariantsContext = VarsparkContext


class FeatureSource(object):
    """Python-side handle on a JVM feature source.

    Holds the gateway objects needed to launch an importance analysis on
    the wrapped Java feature source.
    """

    def __init__(self, _jvm, _vs_api, _jsql, sql, _jfs):
        # Gateway/JVM handles first, then the Python SQL context and the
        # wrapped Java feature source itself.
        self._jvm, self._vs_api, self._jsql = _jvm, _vs_api, _jsql
        self.sql = sql
        self._jfs = _jfs

    @params(self=object, label_source=object, n_trees=Nullable(int), mtry_fraction=Nullable(float),
            oob=Nullable(bool), seed=Nullable(int), batch_size=Nullable(int),
            var_ordinal_levels=Nullable(int), max_depth=int, min_node_size=int)
    def importance_analysis(self, label_source, n_trees=1000, mtry_fraction=None,
                            oob=True, seed=None, batch_size=100, var_ordinal_levels=3,
                            max_depth=java.MAX_INT, min_node_size=1):
        """Run a random forest importance analysis over this feature source.

        :param label_source: The ingested label source.
        :param int n_trees: The number of trees to build in the forest.
        :param float mtry_fraction: The fraction of variables to try at each split.
        :param bool oob: Should OOB error be calculated.
        :param int seed: Random seed to use; a random long is substituted
            via ``java.jlong_or`` when not given.
        :param int batch_size: The number of trees to build in one batch.
        :param int var_ordinal_levels: Forwarded unchanged to the JVM
            ``ImportanceAnalysis`` call.
        :param int max_depth: Forwarded to ``RandomForestParams``.
        :param int min_node_size: Forwarded to ``RandomForestParams``.
        :return: Importance analysis model.
        :rtype: :py:class:`ImportanceAnalysis`
        """
        make_rf_params = self._jvm.au.csiro.variantspark.algo.RandomForestParams

        # Positional arguments of the JVM RandomForestParams constructor.
        # The literal True / NAN / True / False / 0 slots are fixed by this
        # wrapper; their meaning is defined on the Scala side.
        rf_args = (
            bool(oob),
            java.jfloat_or(mtry_fraction),
            True,
            java.NAN,
            True,
            # The fallback seed is drawn on every call, even if `seed` is set.
            java.jlong_or(seed, randint(java.MIN_LONG, java.MAX_LONG)),
            max_depth,
            min_node_size,
            False,
            0,
        )
        forest_params = make_rf_params(*rf_args)

        java_analysis = self._vs_api.ImportanceAnalysis(
            self._jsql, self._jfs, label_source,
            forest_params, n_trees, batch_size, var_ordinal_levels)
        return ImportanceAnalysis(java_analysis, self.sql)


class ImportanceAnalysis(object):
    """ Model for random forest based importance analysis.

    Thin Python wrapper around the JVM importance analysis object.
    """

    def __init__(self, _jia, sql):
        """
        :param _jia: The Java importance analysis object to delegate to.
        :param sql: The SQL context used to look up registered views.
        """
        self._jia = _jia
        self.sql = sql

    @params(self=object, limit=Nullable(int))
    def important_variables(self, limit=10):
        """ Gets the top limit important variables as a list of tuples (name, importance) where:
            - name: string - variable name
            - importance: double - gini importance

        The list is ordered by descending importance.
        """
        jimpvarmap = self._jia.importantVariablesJavaMap(limit)
        return sorted(jimpvarmap.items(), key=lambda x: x[1], reverse=True)

    def oob_error(self):
        """ OOB (Out of Bag) error estimate for the model
        :rtype: float
        """
        return self._jia.oobError()

    def variable_importance(self):
        """ Returns a DataFrame with the gini importance of variables.
        The DataFrame has two columns:
            - variable: string - variable name
            - importance: double - gini importance

        The result is registered as the temporary view ``df``; any existing
        temporary view of that name is replaced.
        """
        jdf = self._jia.variableImportance()
        # NOTE(review): presumably forces the Java DataFrame to be computed
        # before it is exposed to Python -- confirm.
        jdf.count()
        # createTempView fails when a view named "df" already exists, which
        # broke repeated calls in one session; replace the view instead.
        jdf.createOrReplaceTempView("df")
        return self.sql.table("df")

0 comments on commit 80a9c59

Please sign in to comment.