Skip to content

Commit

Permalink
Add predict() to simple tree and simple RF models
Browse files Browse the repository at this point in the history
This avoid creating intermediate tables.
  • Loading branch information
markotoplak committed Dec 23, 2022
1 parent e75f21c commit 28c90fb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
11 changes: 8 additions & 3 deletions Orange/classification/simple_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ def learn(self, learner, data):
tree.seed = learner.seed + i
self.estimators_.append(tree(data))

def predict_storage(self, data):
p = np.zeros((data.X.shape[0], self.cls_vals))
def predict(self, X):
p = np.zeros((X.shape[0], self.cls_vals))
X = np.ascontiguousarray(X) # so that it is a no-op for individual trees
for tree in self.estimators_:
p += tree(data, tree.Probs)
# SimpleTrees do not have preprocessors and domain conversion
# was already handled within this class so we can call tree.predict() directly
# instead of going through tree.__call__
_, pt = tree.predict(X)
p += pt
p /= len(self.estimators_)
return p.argmax(axis=1), p
4 changes: 2 additions & 2 deletions Orange/classification/simple_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def __init__(self, learner, data):
learner.bootstrap,
learner.seed)

def predict_storage(self, data):
X = np.ascontiguousarray(data.X)
def predict(self, X):
X = np.ascontiguousarray(X)
if self.type == Classification:
p = np.zeros((X.shape[0], self.cls_vals))
_tree.predict_classification(
Expand Down
11 changes: 8 additions & 3 deletions Orange/regression/simple_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,14 @@ def __init__(self, learner, data):
self.estimators_ = []
self.learn(learner, data)

def predict_storage(self, data):
p = np.zeros(data.X.shape[0])
def predict(self, X):
p = np.zeros(X.shape[0])
X = np.ascontiguousarray(X) # so that it is a no-op for individual trees
for tree in self.estimators_:
p += tree(data)
# SimpleTrees do not have preprocessors and domain conversion
# was already handled within this class so we can call tree.predict() directly
# instead of going through tree.__call__
pt = tree.predict(X)
p += pt
p /= len(self.estimators_)
return p

0 comments on commit 28c90fb

Please sign in to comment.