Skip to content

Commit

Permalink
Merge pull request #545 from giswqs/bugfix/rf_to_string-labels
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs authored Jun 30, 2021
2 parents d173181 + ed6ae71 commit 02dc20d
Showing 1 changed file with 53 additions and 7 deletions.
60 changes: 53 additions & 7 deletions geemap/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import multiprocessing as mp
from functools import partial


def tree_to_string(estimator, feature_names):
def tree_to_string(estimator, feature_names, labels = None, output_mode="INFER"):
"""Function to convert a sklearn decision tree object to a string format that EE can interpret
args:
estimator (sklearn.tree.estimator): A single fitted decision tree estimator (e.g. one element of an ensemble's estimators_ list). Expects object to contain tree_ attribute
feature_names (list[str]): List of strings that define the name of features (i.e. bands) used to create the model
feature_names (Iterable[str]): List of strings that define the name of features (i.e. bands) used to create the model
kwargs:
labels (Iterable[numeric]): List of class labels to map outputs to, must be numeric values. If None, then raw outputs will be used. default = None
output_mode (str): the output mode of the estimator. Options are "INFER", "CLASSIFICATION", or "REGRESSION" (capitalization does not matter). default = "INFER"
returns:
tree_str (str): string representation of decision tree estimator
Expand All @@ -31,12 +34,32 @@ def tree_to_string(estimator, feature_names):
features = [feature_names[i] for i in feature_idx]

raw_vals = estimator.tree_.value
if raw_vals.ndim == 3:

# first check if user wants to infer output mode
# if so, reset the output_mode variable to a valid mode
if output_mode == "INFER":
if raw_vals.ndim == 3:
output_mode = "CLASSIFICATION"

elif raw_vals.ndim == 2:
output_mode = "REGRESSION"

else:
raise RuntimeError("Could not infer the output type from the estimator, please explicitly provide the output_mode ")

# second check on the output mode after the inference
if output_mode == "CLASSIFICATION":
# take argmax along class axis from values
values = np.squeeze(raw_vals.argmax(axis=-1))
elif raw_vals.ndim == 2:
if labels is not None:
index_labels = np.unique(values)
lookup = {idx:labels[i] for i,idx in enumerate(index_labels)}
values = [lookup[v] for v in values]

elif output_mode == "REGRESSION":
# take values and drop unneeded axis
values = np.squeeze(raw_vals)

else:
raise RuntimeError(
"could not understand estimator type and parse out the values"
Expand Down Expand Up @@ -180,7 +203,7 @@ def tree_to_string(estimator, feature_names):
return tree_str


def rf_to_strings(estimator, feature_names, processes=2):
def rf_to_strings(estimator, feature_names, processes=2, output_mode="INFER"):
"""Function to convert a ensemble of decision trees into a list of strings. Wraps `tree_to_string`
args:
Expand All @@ -189,15 +212,38 @@ def rf_to_strings(estimator, feature_names, processes=2):
kwargs:
processes (int): number of cpu processes to spawn. Increasing processes will improve speed for large models. default = 2
output_mode (str): the output mode of the estimator. Options are "INFER", "CLASSIFICATION", or "REGRESSION" (capitalization does not matter). default = "INFER"
returns:
trees (list[str]): list of strings where each string represents a decision tree estimator and collectively represent an ensemble decision tree estimator (i.e. RandomForest)
"""

# force output mode to be capital
output_mode = output_mode.upper()

available_modes = ["INFER","CLASSIFICATION","REGRESSION"]

if output_mode not in available_modes:
raise ValueError(f"The provided output_mode is not available, please provide one from the following list: {available_modes}")

# extract out the estimator trees
estimators = estimator.estimators_

if output_mode == "INFER":
if estimator.criterion in ["gini","entropy"]:
class_labels = estimator.classes_
elif estimator.criterion in ["mse","mae"]:
class_labels = None
else:
raise RuntimeError("Could not infer the output type from the estimator, please explicitly provide the output_mode ")

elif output_mode == "CLASSIFICATION":
class_labels = estimator.classes_

else:
class_labels = None

# check that number of processors set to use is not more than available
if processes >= mp.cpu_count():
# if so, force to use only cpu count - 1
Expand All @@ -206,7 +252,7 @@ def rf_to_strings(estimator, feature_names, processes=2):
# run the tree extraction process in parallel
with mp.Pool(processes) as pool:
proc = pool.map_async(
partial(tree_to_string, feature_names=feature_names), estimators
partial(tree_to_string, feature_names=feature_names,labels=class_labels, output_mode=output_mode), estimators
)
trees = list(proc.get())

Expand Down

0 comments on commit 02dc20d

Please sign in to comment.