Skip to content

Commit

Permalink
bug fix where output labels from rf_to_string for classification were…
Browse files Browse the repository at this point in the history
… indices not labels
  • Loading branch information
KMarkert committed Jun 25, 2021
1 parent dd82efe commit 200b39c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions geemap/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import multiprocessing as mp
from functools import partial


def tree_to_string(estimator, feature_names):
def tree_to_string(estimator, feature_names, labels = None):
"""Function to convert a sklearn decision tree object to a string format that EE can interpret
args:
estimator (sklearn.tree.estimator): An estimator consisting of multiple decision tree classifiers. Expects object to contain estimators_ attribute
feature_names (list[str]): List of strings that define the name of features (i.e. bands) used to create the model
feature_names (Iterable[str]): List of strings that define the name of features (i.e. bands) used to create the model
labels (Iterable): List of class labels that the argmax class indices are mapped to; if None, the indices are left unchanged
returns:
tree_str (str): string representation of decision tree estimator
Expand All @@ -34,6 +34,11 @@ def tree_to_string(estimator, feature_names):
if raw_vals.ndim == 3:
# take argmax along class axis from values
values = np.squeeze(raw_vals.argmax(axis=-1))
if labels is not None:
index_labels = np.unique(values)
lookup = {idx:labels[i] for i,idx in enumerate(index_labels)}
values = [lookup[v] for v in values]

elif raw_vals.ndim == 2:
# take values and drop unneeded axis
values = np.squeeze(raw_vals)
Expand Down Expand Up @@ -197,6 +202,7 @@ def rf_to_strings(estimator, feature_names, processes=2):

# extract out the estimator trees
estimators = estimator.estimators_
class_labels = estimator.classes_

# check that number of processors set to use is not more than available
if processes >= mp.cpu_count():
Expand All @@ -206,7 +212,7 @@ def rf_to_strings(estimator, feature_names, processes=2):
# run the tree extraction process in parallel
with mp.Pool(processes) as pool:
proc = pool.map_async(
partial(tree_to_string, feature_names=feature_names), estimators
partial(tree_to_string, feature_names=feature_names,labels=class_labels), estimators
)
trees = list(proc.get())

Expand Down

0 comments on commit 200b39c

Please sign in to comment.