Skip to content

Commit

Permalink
DEV: Created standalone ImportanceAnalysis class in
Browse files Browse the repository at this point in the history
separate wrapper file (#237)

REFACTOR: Updated important_variables and variable_importance
methods to convert to pandas DataFrames
  • Loading branch information
NickEdwards7502 committed Sep 11, 2024
1 parent 4560998 commit b8b39fd
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions python/varspark/importanceanalysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pandas as pd
from typedecorator import params, Nullable

class ImportanceAnalysis(object):
    """ Python-side wrapper around a JVM importance-analysis object.

    Exposes the results of the wrapped object as pandas DataFrames.
    """

    def __init__(self, _jia, sql):
        """ Stores the handles used by the accessor methods.
        :param _jia: JVM-side analysis object; must provide ``importantVariablesJavaMap``
                     and ``variableImportance``
        :param sql: Object whose ``table(name)`` result supports ``toPandas()``
                    (presumably a Spark SQLContext/SparkSession -- confirm against caller)
        """
        self._jia = _jia
        self.sql = sql

    @params(self=object, limit=Nullable(int), normalized=Nullable(bool))
    def important_variables(self, limit=10, normalized=False):
        """ Gets the top limit important variables
        :param (int) limit: Indicates how many of the most important variables to return.
                            None falls back to the default of 10.
        :param (bool) normalized: Indicates whether to return normalized importances.
                                  None falls back to the default of False.
        :return topimportances (pd.DataFrame): DataFrame with columns 'variable' and
                                'importance', sorted by importance in descending order.
        """
        # The type check accepts None for both arguments; substitute the signature
        # defaults instead of forwarding None to the JVM call.
        if limit is None:
            limit = 10
        if normalized is None:
            normalized = False

        jimpvarmap = self._jia.importantVariablesJavaMap(limit, normalized)
        # Order the (variable, importance) pairs from most to least important.
        ranked = sorted(jimpvarmap.items(), key=lambda item: item[1], reverse=True)
        return pd.DataFrame(ranked, columns=['variable', 'importance'])

    @params(self=object, precision=Nullable(int), normalized=Nullable(bool))
    def variable_importance(self, precision=None, normalized=False):
        """ Returns a DataFrame with the importance of variables.
        :param (int) precision: Number of digits the 'importance' column is rounded to.
                                None leaves the values unrounded.
        :param (bool) normalized: Indicates whether to return normalized importances.
                                  None falls back to the default of False.
        :return importances (pd.DataFrame): pandas conversion of the table produced by the
                                JVM call. It must contain an 'importance' column when
                                precision is given; other columns are passed through as-is.
        """
        # Same None handling as important_variables: do not forward None to the JVM call.
        if normalized is None:
            normalized = False

        jdf = self._jia.variableImportance(normalized)
        # NOTE(review): the count is discarded; presumably this forces evaluation before
        # the view is registered -- confirm it is still required.
        jdf.count()
        # NOTE(review): the fixed view name "df" presumably replaces any existing temp view
        # called "df" in the caller's session, and the view stays registered afterwards.
        # Consider a unique name plus cleanup.
        jdf.createOrReplaceTempView("df")
        importances = self.sql.table("df").toPandas()

        if precision is not None:
            # Python's round() is applied per element to keep results identical to callers'
            # existing expectations.
            importances['importance'] = importances['importance'].apply(
                lambda value: round(value, precision))
        return importances

0 comments on commit b8b39fd

Please sign in to comment.